diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs index 456097ff9db9a..bde39d9c6e6cc 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs @@ -4,6 +4,7 @@ namespace Microsoft.ML.OnnxRuntime { using System; + using System.Diagnostics; using System.Runtime.InteropServices; /// @@ -22,7 +23,7 @@ public enum OrtCompileApiFlags : uint /// This class is used to set options for model compilation, and to produce a compiled model using those options. /// See https://onnxruntime.ai/docs/api/c/ for further details of various options. /// - public class OrtModelCompilationOptions : SafeHandle + public class OrtModelCompilationOptions : IDisposable { /// /// Create a new OrtModelCompilationOptions object from SessionOptions. @@ -31,11 +32,10 @@ public class OrtModelCompilationOptions : SafeHandle /// to enable graph optimizations. /// SessionOptions instance to read settings from. 
public OrtModelCompilationOptions(SessionOptions sessionOptions) - : base(IntPtr.Zero, true) { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtCreateModelCompilationOptionsFromSessionOptions( - OrtEnv.Instance().Handle, sessionOptions.Handle, out handle)); + OrtEnv.Instance().Handle, sessionOptions.Handle, out _handle)); } /// @@ -43,7 +43,7 @@ public OrtModelCompilationOptions(SessionOptions sessionOptions) /// public void CompileModel() { - NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, handle)); + NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, _handle)); } @@ -55,7 +55,7 @@ public void SetInputModelPath(string path) { var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(handle, platformPath)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(_handle, platformPath)); } /// @@ -67,7 +67,7 @@ public void SetInputModelFromBuffer(byte[] buffer) { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelFromBuffer( - handle, buffer, (UIntPtr)buffer.Length)); + _handle, buffer, (UIntPtr)buffer.Length)); } /// @@ -78,7 +78,7 @@ public void SetOutputModelPath(string path) { var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(handle, platformPath)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(_handle, platformPath)); } @@ -93,7 +93,7 @@ public void SetOutputModelExternalInitializersFile(string filePath, ulong thresh var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath); NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelExternalInitializersFile( - handle, 
platformPath, new UIntPtr(threshold))); + _handle, platformPath, new UIntPtr(threshold))); } // TODO: In order to use this to create an InferenceSession without copying bytes we need more infrastructure. @@ -108,7 +108,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator, { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelBuffer( - handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr)); + _handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr)); } /// @@ -119,7 +119,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator, public void SetEpContextEmbedMode(bool embed) { NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(_handle, embed)); } /// @@ -129,7 +129,7 @@ public void SetEpContextEmbedMode(bool embed) public void SetFlags(OrtCompileApiFlags flags) { NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(_handle, (uint)flags)); } /// @@ -145,7 +145,7 @@ public void SetEpContextBinaryInformation(string outputDirectory, string modelNa var platformModelName = NativeOnnxValueHelper.GetPlatformSerializedString(modelName); NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation( - handle, platformOutputDirectory, platformModelName)); + _handle, platformOutputDirectory, platformModelName)); } /// @@ -156,26 +156,352 @@ public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLe { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel( - handle, graphOptimizationLevel)); + _handle, graphOptimizationLevel)); } - internal IntPtr Handle => handle; + /// 
+ /// Delegate to write/save a buffer containing ONNX model bytes to a custom destination. The delegate + /// may be called repeatedly until the entire output model has been written out. Each call to the delegate + /// is expected to consume the entire buffer. + /// + /// The buffer to write out. + /// + public delegate void WriteBufferToDestinationDelegate(ReadOnlySpan buffer); + + /// + /// Sets a delegate that is called by ORT to write out the output model's serialized ONNX bytes. + /// The provided delegate may be called repeatedly until the entire output model has been written out. + /// Each call to the delegate is expected to consume/handle the entire input buffer. + /// + /// The delegate called by ORT to write out the model. + public void SetOutputModelWriteDelegate(WriteBufferToDestinationDelegate writeBufferDelegate) + { + _writeBufferToDestinationDelegateState?.Dispose(); + _writeBufferToDestinationDelegateState = + new DelegateResources( + new WriteBufferToDestinationConnector(writeBufferDelegate), + new NativeMethods.DOrtWriteBufferToDestinationDelegate( + WriteBufferToDestinationConnector.WriteBufferToDestinationDelegateWrapper)); + + IntPtr funcPtr = _writeBufferToDestinationDelegateState.GetFunctionPointerForDelegate(); + + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelWriteFunc( + _handle, + funcPtr, + _writeBufferToDestinationDelegateState.GetConnectorHandleAsPointer())); + } + + /// + /// Delegate called by ORT for every initializer when generating the compiled model. + /// The delegate allows the user to determine whether the initializer should be stored within the compiled + /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate + /// implementation is responsible for writing the initializer data to a file. + /// + /// The initializer's name. 
+ /// The readonly OrtValue instance containing the data, type, and + /// shape of the initializer. + /// May be null. If the initializer is originally stored externally, + /// this contains the file path, file offset, and data size. Otherwise, this is null. + /// A new OrtExternalInitializerInfo indicating the new location of the initializer. + /// Returns null if the initializer should be stored within the generated compiled model. + /// The return value may be null. + /// + public delegate OrtExternalInitializerInfo GetInitializerLocationDelegate( + string initializerName, + IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation); + + /// + /// Sets a delegate that is called by ORT for every initializer when generating the compiled model. + /// The delegate allows the user to determine whether the initializer should be stored within the compiled + /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate + /// implementation is responsible for writing the initializer data to a file. + /// + /// The delegate called by ORT for every initializer. 
+ public void SetOutputModelGetInitializerLocationDelegate( + GetInitializerLocationDelegate getInitializerLocationDelegate) + { + _getInitializerLocationDelegateState?.Dispose(); + _getInitializerLocationDelegateState = + new DelegateResources( + new GetInitializerLocationConnector(getInitializerLocationDelegate), + new NativeMethods.DOrtGetInitializerLocationDelegate( + GetInitializerLocationConnector.GetInitializerLocationDelegateWrapper)); + + IntPtr funcPtr = _getInitializerLocationDelegateState.GetFunctionPointerForDelegate(); + + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + _handle, + funcPtr, + _getInitializerLocationDelegateState.GetConnectorHandleAsPointer())); + } + + #region Delegate helpers + /// + /// Class to bridge the C# and native worlds for the "write buffer to destination" delegate + /// + private class WriteBufferToDestinationConnector + { + private readonly WriteBufferToDestinationDelegate _userDelegate; + + internal WriteBufferToDestinationConnector(WriteBufferToDestinationDelegate writeBufferDelegate) + { + _userDelegate = writeBufferDelegate; + } + + public static IntPtr WriteBufferToDestinationDelegateWrapper(IntPtr /* void* */ state, + IntPtr /* const void* */ buffer, + UIntPtr /* size_t */ bufferNumBytes) + { + try + { + + WriteBufferToDestinationConnector connector = (WriteBufferToDestinationConnector) + GCHandle.FromIntPtr(state).Target; + ReadOnlySpan bufferSpan; + + unsafe + { + // NOTE: A Span can only view 2GB of data. This is fine because ORT does not write out + // chunks that large. However, if we ever need to, the solution is to just write a loop here + // that repeatedly calls the delegate with smaller chunks of data. 
+ bufferSpan = new ReadOnlySpan(buffer.ToPointer(), checked((int)bufferNumBytes)); + } + + connector._userDelegate(bufferSpan); + } + catch (Exception ex) + { + var error = $"The C# WriteBufferToDestination delegate threw an exception: {ex.Message}"; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + return IntPtr.Zero; + } + } + + /// + /// Class to bridge the C# and native worlds for the "get initializer location" delegate + /// + private class GetInitializerLocationConnector + { + private readonly GetInitializerLocationDelegate _userDelegate; + + internal GetInitializerLocationConnector(GetInitializerLocationDelegate getInitializerLocationDelegate) + { + _userDelegate = getInitializerLocationDelegate; + } + + public static IntPtr GetInitializerLocationDelegateWrapper( + IntPtr /* void* */ state, + IntPtr /* const char* */ initializerName, + IntPtr /* const OrtValue* */ initializerValue, + IntPtr /* const OrtExternalInitializerInfo* */ originalInitializerLocation, + out IntPtr /* OrtExternalInitializerInfo** */ newInitializerLocationOutput) + { + newInitializerLocationOutput = IntPtr.Zero; + + try + { + + GetInitializerLocationConnector connector = (GetInitializerLocationConnector)GCHandle. + FromIntPtr(state).Target; + string utf8InitializerName = NativeOnnxValueHelper.StringFromNativeUtf8(initializerName); + IReadOnlyOrtValue readOnlyInitializerValue = new OrtValue(initializerValue, owned: false); + IReadOnlyExternalInitializerInfo readOnlyOriginalInitializerLocation = null; + + if (originalInitializerLocation != IntPtr.Zero) + { + readOnlyOriginalInitializerLocation = new OrtExternalInitializerInfo( + originalInitializerLocation, ownsHandle: false); + } + // Call user's delegate, which may return the new location of the initializer. 
+ OrtExternalInitializerInfo newInitializerLocation = connector._userDelegate( + utf8InitializerName, readOnlyInitializerValue, readOnlyOriginalInitializerLocation); + + if (newInitializerLocation != null) + { + // Delegate returned info about a new location for the initializer. + // Can't guarantee that the new external info returned by user's delegate is not referenced + // by other C# code. ORT expects to own the new external info, so create a copy here and + // give it to ORT. + string newFilePath = newInitializerLocation.GetFilePath(); + byte[] newFilePathBytes = NativeOnnxValueHelper.GetPlatformSerializedString(newFilePath); + + IntPtr status = NativeMethods.OrtCreateExternalInitializerInfo( + newFilePathBytes, + newInitializerLocation.GetFileOffset(), + (UIntPtr)newInitializerLocation.GetByteSize(), + out newInitializerLocationOutput); + + if (status != IntPtr.Zero) + { + return status; + } + } + else + { + // User's delegate did not return a new location for the initializer. ORT will store initializer + // within the generated compiled model. + newInitializerLocationOutput = IntPtr.Zero; + } + } + catch (Exception ex) + { + var error = $"The C# GetInitializerLocation delegate threw an exception: {ex.Message}"; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + return IntPtr.Zero; + } + } /// - /// Indicates whether the native handle is invalid. + /// Disposable class that stores resources for a delegate provided by the user. /// - public override bool IsInvalid => handle == IntPtr.Zero; + /// The type of the connector class + /// (e.g., WriteBufferToDestinationConnector) + /// The type of the native delegate. 
+ private class DelegateResources : IDisposable + where Connector : class + where Delegate : class + { + public DelegateResources(Connector connector, Delegate @delegate) + { + _connector = connector; + _delegate = @delegate; + _connectorHandle = GCHandle.Alloc(_connector); + _delegateHandle = GCHandle.Alloc(_delegate); + } + + internal IntPtr GetFunctionPointerForDelegate() + { + return Marshal.GetFunctionPointerForDelegate(_delegate); + } + + internal IntPtr GetConnectorHandleAsPointer() + { + return GCHandle.ToIntPtr(_connectorHandle); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + // Dispose other children disposables. We have none. + } + if (_connectorHandle.IsAllocated) + { + _connectorHandle.Free(); + _connector = null; + } + + if (_delegateHandle.IsAllocated) + { + _delegateHandle.Free(); + _delegate = null; + } + + _disposed = true; + } + + ~DelegateResources() + { + Dispose(false); + } + + private Connector _connector = null; + private Delegate _delegate = null; + private GCHandle _connectorHandle = default; + private GCHandle _delegateHandle = default; + private bool _disposed = false; + } + #endregion + + #region IDispose implementation /// - /// Release the native instance of OrtModelCompilationOptions. + /// IDispose implementation. /// - /// true - protected override bool ReleaseHandle() + public void Dispose() { - NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(handle); - handle = IntPtr.Zero; - return true; + Dispose(true); + GC.SuppressFinalize(this); } + + /// + /// IDispose implementation + /// + /// True if Dispose() has been called by the user-side code. False if + /// called by the runtime from inside the finalizer. 
+ protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + _writeBufferToDestinationDelegateState?.Dispose(); + _getInitializerLocationDelegateState?.Dispose(); + } + + Debug.Assert(_handle != IntPtr.Zero); + NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(_handle); + _handle = IntPtr.Zero; + _disposed = true; + } + + /// + /// Finalizer that releases the native handle if not already released by Dispose(). + /// + ~OrtModelCompilationOptions() + { + Dispose(false); + } + #endregion + + /// + /// Handle to the native OrtModelCompilationOptions object. + /// + private IntPtr _handle; + + /// + /// True if this OrtModelCompilationOptions instance has already been disposed. + /// + private bool _disposed = false; + + /// + /// Stores delegate state for the "write buffer to destination" delegate. + /// + private DelegateResources + _writeBufferToDestinationDelegateState = null; + + /// + /// Stores delegate state for the "get initializer location" delegate. 
+ /// + private DelegateResources + _getInitializerLocationDelegateState = null; } -} \ No newline at end of file +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs index 9d25d96bdaa5a..84020d84c9e73 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -23,6 +23,8 @@ public struct OrtCompileApi public IntPtr ModelCompilationOptions_SetFlags; public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation; public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel; + public IntPtr ModelCompilationOptions_SetOutputModelWriteFunc; + public IntPtr ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; } internal class NativeMethods @@ -118,6 +120,22 @@ public DOrtModelCompilationOptions_SetEpContextBinaryInformation public DOrtModelCompilationOptions_SetGraphOptimizationLevel OrtModelCompilationOptions_SetGraphOptimizationLevel; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelWriteFunc( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* DOrtWriteBufferDelegate */ writeFunc, + IntPtr /* void* */ state); + public DOrtModelCompilationOptions_SetOutputModelWriteFunc + OrtModelCompilationOptions_SetOutputModelWriteFunc; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* DOrtHandleInitializerDataDelegate */ handleInitializerFunc, + IntPtr /* void* */ state); + public DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc + OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; + internal 
NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) { @@ -188,6 +206,17 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi _compileApi.ModelCompilationOptions_SetGraphOptimizationLevel, typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel)); + OrtModelCompilationOptions_SetOutputModelWriteFunc = + (DOrtModelCompilationOptions_SetOutputModelWriteFunc)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelWriteFunc, + typeof(DOrtModelCompilationOptions_SetOutputModelWriteFunc)); + + OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc = + (DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)Marshal. + GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + typeof(DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)); + } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 3c92400715740..53880308da261 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -450,6 +450,7 @@ public struct OrtApi public IntPtr Graph_GetModelMetadata; public IntPtr GetModelCompatibilityForEpDevices; + public IntPtr CreateExternalInitializerInfo; } internal static class NativeMethods @@ -787,9 +788,35 @@ static NativeMethods() api_.SessionOptionsSetEpSelectionPolicyDelegate, typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); + OrtReleaseExternalInitializerInfo = + (DOrtReleaseExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseExternalInitializerInfo, + typeof(DOrtReleaseExternalInitializerInfo)); + + OrtExternalInitializerInfo_GetFilePath = + (DOrtExternalInitializerInfo_GetFilePath)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetFilePath, + 
typeof(DOrtExternalInitializerInfo_GetFilePath)); + + OrtExternalInitializerInfo_GetFileOffset = + (DOrtExternalInitializerInfo_GetFileOffset)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetFileOffset, + typeof(DOrtExternalInitializerInfo_GetFileOffset)); + + OrtExternalInitializerInfo_GetByteSize = + (DOrtExternalInitializerInfo_GetByteSize)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetByteSize, + typeof(DOrtExternalInitializerInfo_GetByteSize)); + OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer( api_.GetModelCompatibilityForEpDevices, typeof(DOrtGetModelCompatibilityForEpDevices)); + + OrtCreateExternalInitializerInfo = + (DOrtCreateExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer( + api_.CreateExternalInitializerInfo, + typeof(DOrtCreateExternalInitializerInfo)); + } internal class NativeLib @@ -2382,6 +2409,70 @@ out IntPtr lora_adapter public delegate ref CompileApi.OrtCompileApi DOrtGetCompileApi(); #endif public static DOrtGetCompileApi OrtGetCompileApi; + + /// + /// Delegate called by ORT to write a buffer (ONNX model bytes) to a custom destination (e.g., file or stream). + /// + /// State that was provided in when the delegate was registered. + /// The buffer to write. + /// The size of the buffer in bytes. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtWriteBufferToDestinationDelegate( + IntPtr /* void* */ state, + IntPtr /* const void* */ buffer, + UIntPtr /* size_t */ bufferNumBytes + ); + + /// + /// Function called by ORT to allow user to specify how an initializer should be saved while compiling + /// a model, that is, either written to an external file or stored within the model. ORT calls this function + /// for every initializer. + /// + /// State that was provided when the delegate was registered. + /// The initializer's name. 
+ /// The OrtValue containing the initializer's data, type, and shape + /// The original initializer's location in an external file, or NULL. + /// Output parameter set to a new OrtExternalInitializerInfo instance + /// indicating the location where the function implementation stored the initializer data. If the function + /// implementation sets `newExternalInfo` to NULL, ORT stores the initializer within the generated model. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetInitializerLocationDelegate( + IntPtr /* void* */ state, + IntPtr /* const char* */ initializerName, + IntPtr /* const OrtValue* */ initializerValue, + IntPtr /* const OrtExternalInitializerInfo* */ externalInfo, + out IntPtr /* OrtExternalInitializerInfo** */ newExternalInfo + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseExternalInitializerInfo(IntPtr /* OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateExternalInitializerInfo( + byte[] /* const ORTCHAR_T* */ filePath, + long /* int64_t */ fileOffset, + UIntPtr /* size_t */ byteSize, + out IntPtr /* OrtExternalInitializerInfo** */ outInfo); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const ORTCHAR_T* */ DOrtExternalInitializerInfo_GetFilePath( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate long /* int64_t */ DOrtExternalInitializerInfo_GetFileOffset( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate UIntPtr /* size_t */ DOrtExternalInitializerInfo_GetByteSize( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + public static DOrtReleaseExternalInitializerInfo OrtReleaseExternalInitializerInfo; + public static 
DOrtCreateExternalInitializerInfo OrtCreateExternalInitializerInfo; + public static DOrtExternalInitializerInfo_GetFilePath OrtExternalInitializerInfo_GetFilePath; + public static DOrtExternalInitializerInfo_GetFileOffset OrtExternalInitializerInfo_GetFileOffset; + public static DOrtExternalInitializerInfo_GetByteSize OrtExternalInitializerInfo_GetByteSize; #endregion #region Auto EP API related diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs index fc14be00ee47b..4611428ea12ef 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs @@ -150,6 +150,45 @@ internal static byte[] GetPlatformSerializedString(string str) else return StringToZeroTerminatedUtf8(str); } + + /// + /// Converts a null-terminated path string that is pointed to by the given IntPtr handle into + /// a C# UTF-16 string. + /// + /// A path string on Windows is utf-16, but utf-8 on other operating systems. + /// + /// + internal static string StringFromNativePathString(IntPtr strPtr) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + if (strPtr == IntPtr.Zero) + { + return string.Empty; + } + + // Get length of utf16 string by checking for two 0 bytes in a row. 
+ int length = 0; + while (Marshal.ReadInt16(strPtr, length * 2) != 0) + { + length += 1; + } + + if (length == 0) + { + return string.Empty; + } + + unsafe + { + return System.Text.Encoding.Unicode.GetString((byte*)strPtr, length * 2); + } + } + else + { + return StringFromNativeUtf8(strPtr); + } + } } // Guards an array of disposable objects on stack and disposes them in reverse order diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs new file mode 100644 index 0000000000000..aca16e939ce21 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Diagnostics; + using System.Runtime.InteropServices; + + /// + /// Class to that stores information about the file location where an "external" initializer is stored. + /// + /// + public class OrtExternalInitializerInfo : SafeHandle, IReadOnlyExternalInitializerInfo + { + // Set to false when constructed with an externally managed constant handle owned by ORT. + private readonly bool _ownsHandle = true; + + /// + /// Create a new OrtExternalInitializerInfo instance. + /// + /// The path to the file that stores the initializer data. + /// The byte offset in the file where the data is stored. + /// The size of the data (in bytes) within the file. 
+ public OrtExternalInitializerInfo(string filePath, long fileOffset, long byteSize) + : base(IntPtr.Zero, ownsHandle: true) + { + var platformFilePath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath); + NativeApiStatus.VerifySuccess( + NativeMethods.OrtCreateExternalInitializerInfo(platformFilePath, fileOffset, (UIntPtr)byteSize, out handle)); + _ownsHandle = true; + } + + /// + /// Create a new OrtExternalInitializerInfo instance from an existing native OrtExternalInitializerInfo handle. + /// + /// Native OrtExternalInitializerInfo handle. + /// True if the OrtExternalInitializerInfo instance owns the native handle. + /// Defaults to false. + internal OrtExternalInitializerInfo(IntPtr constHandle, bool ownsHandle = false) + : base(IntPtr.Zero, ownsHandle) + { + Debug.Assert(constHandle != IntPtr.Zero); + SetHandle(constHandle); + _ownsHandle = ownsHandle; + } + + /// + /// Get the file path to the file that store's the initializer's data. + /// + /// + /// The path is relative to the filesystem directory where the ONNX model was stored. + /// + /// The file path. + public string GetFilePath() + { + IntPtr filePathPtr = NativeMethods.OrtExternalInitializerInfo_GetFilePath(handle); + if (filePathPtr == IntPtr.Zero) + { + return string.Empty; + } + + return NativeOnnxValueHelper.StringFromNativePathString(filePathPtr); + } + + /// + /// Get the byte offset within the file where the initializer's data is stored. + /// + /// The file offset location. + public long GetFileOffset() + { + return NativeMethods.OrtExternalInitializerInfo_GetFileOffset(handle); + } + + /// + /// Get the size in bytes of the initializer's data within the file. + /// + /// The size in bytes of the initializer data. + public long GetByteSize() + { + UIntPtr byteSize = NativeMethods.OrtExternalInitializerInfo_GetByteSize(handle); + return checked((long)byteSize); + } + + /// + /// Indicates whether the native handle is invalid. 
+ /// + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + /// + /// Release the native instance of OrtExternalInitializerInfo if we own it. + /// + /// true on success and false on error. + protected override bool ReleaseHandle() + { + if (!_ownsHandle) + { + // Return false to indicate an error. + // ReleaseHandle() should not be called on a const handle that this class does not own. + return false; + } + + NativeMethods.OrtReleaseExternalInitializerInfo(handle); + handle = IntPtr.Zero; + return true; + } + } + + /// + /// Interface for all readonly methods implemented by OrtExternalInitializerInfo. + /// + public interface IReadOnlyExternalInitializerInfo + { + /// + /// Get the file path to the file that store's the initializer's data. + /// + /// + /// The path is relative to the filesystem directory where the ONNX model was stored. + /// + /// The file path. + string GetFilePath(); + + /// + /// Get the byte offset within the file where the initializer's data is stored. + /// + /// The file offset location. + long GetFileOffset(); + + /// + /// Get the size in bytes of the initializer's data within the file. + /// + /// The size in bytes of the initializer data. + long GetByteSize(); + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 01ee3aa5ae753..d848c63450ec1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -33,6 +33,147 @@ public enum OnnxValueType ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKNOWN) } + /// + /// Interface for all readonly methods implemented by OrtValue. + /// + public interface IReadOnlyOrtValue + { + /// + /// Get the ONNX value type for the OrtValue (e.g., OnnxValueType.ONNX_TYPE_TENSOR). 
+ /// + /// OnnxValueType + OnnxValueType OnnxType { get; } + + /// + /// Returns true if OrtValue contains a tensor + /// + /// true if tensor + bool IsTensor { get; } + + /// + /// Returns true if OrtValue contains a sparse tensor + /// + /// true if sparse tensor + bool IsSparseTensor { get; } + + /// + /// Returns type information about the contained OnnxValue. + /// + /// a disposable instance of OrtTypeInfo + OrtTypeInfo GetTypeInfo(); + + /// + /// Obtains Tensor And Type Information from the OrtValue iff it contains a tensor. + /// Valid only for OrtValues that contain a tensor. + /// + /// A disposable instance of OrtTensorTypeAndShapeInfo + OrtTensorTypeAndShapeInfo GetTensorTypeAndShape(); + + /// + /// Returns the size of the tensor data in bytes. + /// + /// size of the tensor data in bytes + long GetTensorSizeInBytes(); + + /// + /// Returns OrtMemoryInfo iff this OrtValue contains a tensor or a sparse tensor. + /// + /// OrtMemoryInfo that describes the underlying memory allocation + /// + OrtMemoryInfo GetTensorMemoryInfo(); + + /// + /// Returns a ReadOnlySpan over tensor native buffer that + /// provides a read-only view. + /// + /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU. + /// To get memory descriptor use GetTensorMemoryInfo(). + /// + /// OrtValue must contain a non-string tensor. + /// The span is valid as long as the OrtValue instance is alive (not disposed). + /// + /// + /// ReadOnlySpan + /// + ReadOnlySpan GetTensorDataAsSpan() where T : unmanaged; + +#if NET8_0_OR_GREATER + /// + /// Returns a ReadOnlyTensorSpan over tensor native buffer that + /// provides a read-only view. + /// + /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU. + /// To get memory descriptor use GetTensorMemoryInfo(). + /// + /// OrtValue must contain a non-string tensor. + /// The span is valid as long as the OrtValue instance is alive (not disposed). 
+ /// + /// + /// ReadOnlySpan + /// + [Experimental("SYSLIB5001")] + SystemNumericsTensors.ReadOnlyTensorSpan GetTensorDataAsTensorSpan() where T : unmanaged; +#endif + + /// + /// Valid for composite ML types like map, sequence. + /// Returns 2 for map (keys, values) and N for sequence, where N is the number of elements + /// in the sequence. + /// + /// Element count + int GetValueCount(); + + /// + /// For non tensors return OrtValue element at the specified index. + /// For maps only indices 0 and 1 are valid. For sequences, [0..N) are valid. + /// See GetValueCount() to determine the valid range. + /// + /// + /// allocator to use + /// OrtValue disposable instance that points to the corresponding element of the composite type + OrtValue GetValue(int index, OrtAllocator allocator); + + /// + /// Fetch string tensor element buffer pointer at the specified index, + /// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance. + /// + /// Obtain TensorTypeAndShape to get shape and element count. + /// + /// flat string tensor element index + /// ReadOnlyMemory{char} backed by a managed char[]. Its lifespan is not + /// tied to the native buffer of OrtValue. + ReadOnlyMemory GetStringElementAsMemory(int index); + + /// + /// Fetch string tensor element buffer pointer at the specified index, + /// copy/convert UTF-8 into a UTF-16 string and return it. + /// + /// Obtain TensorTypeAndShape to get shape and element count. + /// + /// flat string tensor element index + /// UTF-16 string instance + string GetStringElement(int index); + + /// + /// Get a span over the native memory of the string tensor element. + /// The span is valid as long as the OrtValue is valid. + /// + /// This is useful if you want to perform your own UTF-8 decoding or + /// you do not care about decoding. + /// Obtain TensorTypeAndShape to get shape and element count. 
+ /// + /// flat element index + /// ReadOnlySpan over UTF-8 bytes of the string tensor element + ReadOnlySpan GetStringElementAsSpan(int index); + + /// + /// Convenience method to obtain all string tensor elements as a string array. + /// + /// string[] + /// + string[] GetStringTensorAsArray(); + } + /// /// Represents a disposable OrtValue. /// This class exposes a native instance of OrtValue. @@ -44,7 +185,7 @@ public enum OnnxValueType /// disposed properly, the pinned memory will continue to be pinned and interfere /// with GC operation. /// - public class OrtValue : IOrtValueOwner, IDisposable + public class OrtValue : IOrtValueOwner, IDisposable, IReadOnlyOrtValue { // OrtValues that are members of Sequences or Maps that map. They potentially map managed memory and we need to keep them around. // this exists only when we deal with compose ML types. @@ -52,11 +193,20 @@ public class OrtValue : IOrtValueOwner, IDisposable private IntPtr _handle; private MemoryHandle? _memHandle; // Present when the OrtValue is created on top of managed memory private bool _disposed; + private bool _owned = true; - internal OrtValue(IntPtr handle) + /// + /// Constructs OrtValue from a native handle. If `owned` is true, the OrtValue instance takes + /// ownership of the native handle and disposes it when the OrtValue instance is disposed. + /// + /// The native OrtValue handle. + /// True if this class instance owns the handle. If false, the handle + /// will not be released. Defaults to true. 
+ internal OrtValue(IntPtr handle, bool owned = true) { _handle = handle; InitOnnxType(); + _owned = owned; } /// @@ -1464,7 +1614,10 @@ protected virtual void Dispose(bool disposing) } Debug.Assert(_handle != IntPtr.Zero); - NativeMethods.OrtReleaseValue(_handle); + if (_owned) + { + NativeMethods.OrtReleaseValue(_handle); + } _handle = IntPtr.Zero; _disposed = true; } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs index f1eef57e03ea5..fe2cab57658c8 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs @@ -21,104 +21,249 @@ public class CompileApiTests [Fact] public void BasicUsage() { - var so = new SessionOptions(); - using (var compileOptions = new OrtModelCompilationOptions(so)) + using (var sessionOptions = new SessionOptions()) { - // mainly checking these don't throw which ensures all the plumbing for the binding works. - compileOptions.SetInputModelPath("model.onnx"); - compileOptions.SetOutputModelPath("compiled_model.onnx"); + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + // mainly checking these don't throw which ensures all the plumbing for the binding works. 
+ compileOptions.SetInputModelPath("model.onnx"); + compileOptions.SetOutputModelPath("compiled_model.onnx"); - compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); - compileOptions.SetEpContextEmbedMode(true); - compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC); + compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); + compileOptions.SetEpContextEmbedMode(true); + compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC); - } + } - // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer - using (var compileOptions = new OrtModelCompilationOptions(so)) - { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - compileOptions.SetInputModelFromBuffer(model); + // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + compileOptions.SetInputModelFromBuffer(model); - // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. - // Due to that we need to allocate an IntPtr and UIntPtr here. - IntPtr bytePtr = new IntPtr(); - UIntPtr bytesSize = new UIntPtr(); - var allocator = OrtAllocator.DefaultInstance; - compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); - compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx"); + // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. + // Due to that we need to allocate an IntPtr and UIntPtr here. 
+ IntPtr bytePtr = new IntPtr(); + UIntPtr bytesSize = new UIntPtr(); + var allocator = OrtAllocator.DefaultInstance; + compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); + compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx"); - compileOptions.CompileModel(); + compileOptions.CompileModel(); - Assert.NotEqual(IntPtr.Zero, bytePtr); - Assert.NotEqual(UIntPtr.Zero, bytesSize); + Assert.NotEqual(IntPtr.Zero, bytePtr); + Assert.NotEqual(UIntPtr.Zero, bytesSize); - byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; - Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); + byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; + Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); - // Check the compiled model is valid - using (var session = new InferenceSession(compiledBytes, so)) - { - Assert.NotNull(session); + // Check the compiled model is valid + using (var session = new InferenceSession(compiledBytes, sessionOptions)) + { + Assert.NotNull(session); + } + + allocator.FreeMemory(bytePtr); } - allocator.FreeMemory(bytePtr); - } + // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate + // any compiled EPContext nodes, so expect an ORT_FAIL error. + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var output_model_file = "should_not_generate.onnx"; + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(output_model_file); + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED); - // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate - // any compiled EPContext nodes, so expect an ORT_FAIL error. 
- using (var compileOptions = new OrtModelCompilationOptions(so)) - { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - var output_model_file = "should_not_generate.onnx"; - compileOptions.SetInputModelFromBuffer(model); - compileOptions.SetOutputModelPath(output_model_file); - compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED); + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("Unable to compile any nodes", ex.Message); + } - // compile should fail + Assert.False(File.Exists(output_model_file)); // Output file should not be generated. + } + + // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS. + var outputModelFile = "squeezenet_ctx.onnx"; try { - compileOptions.CompileModel(); - Assert.Fail("CompileModel() should have thrown an exception"); + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFile); + compileOptions.CompileModel(); + Assert.True(File.Exists(outputModelFile)); + + // Try to compile again with flag that prevents replacing an existing file. + // Expect failure. + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS); + + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("exists already", ex.Message); + } + } } - catch (OnnxRuntimeException ex) + finally { - Assert.Contains("Unable to compile any nodes", ex.Message); + if (File.Exists(outputModelFile)) + { + // This file is created by ORT, so we delete it manually in finally block. 
+ File.Delete(outputModelFile); + } } - - Assert.False(File.Exists(output_model_file)); // Output file should not be generated. } + } + + [Fact] + public void WriteOutModelWithDelegate() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var outputModelFilePath = "squeezenet_write_delegate_ctx.onnx"; - // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS. - using (var compileOptions = new OrtModelCompilationOptions(so)) + using (FileStream fs = new FileStream(outputModelFilePath, FileMode.Create, FileAccess.Write, FileShare.None, + 4096, FileOptions.DeleteOnClose)) + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - var output_model_file = "squeezenet_ctx.onnx"; + void BasicWriteBufferDelegate(ReadOnlySpan buffer) + { + Assert.True(buffer.Length > 0); + fs.Write(buffer.ToArray(), 0, buffer.Length); // Write it out to a file + } // Compile and generate an output model. compileOptions.SetInputModelFromBuffer(model); - compileOptions.SetOutputModelPath(output_model_file); + compileOptions.SetOutputModelWriteDelegate(BasicWriteBufferDelegate); compileOptions.CompileModel(); - Assert.True(File.Exists(output_model_file)); + Assert.True(File.Exists(outputModelFilePath)); + } + } - // Try to compile again with flag that prevents replacing an existing file. - // Expect failure. 
- compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS); + [Fact] + public void BasicGetInitializerLocationDelegate() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var outputModelFilePath = "squeezenet_handle_initializer_delegate_ctx.onnx"; + var initializersFilePath = "squeezenet_handle_initializer_delegate_ctx.bin"; - // compile should fail - try + try + { + using (FileStream fs = new FileStream(initializersFilePath, FileMode.Create, FileAccess.Write, + FileShare.None, 4096, FileOptions.DeleteOnClose)) + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) { + // Custom delegate that stores large initializers in a new file. + OrtExternalInitializerInfo BasicHandleInitializer( + string initializerName, IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation) + { + Assert.True(initializerName.Length > 0); + + var byteSize = initializerValue.GetTensorSizeInBytes(); + if (byteSize <= 64) + { + // Keep small initializers stored within model. + return null; + } + + long byteOffset = fs.Position; + ReadOnlySpan dataSpan = initializerValue.GetTensorDataAsSpan(); + fs.Write(dataSpan.ToArray(), 0, dataSpan.Length); // Write it out to a file + + // Return the data's new location. + return new OrtExternalInitializerInfo(initializersFilePath, byteOffset, byteSize); + } + + // Compile and generate an output model. 
+ compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFilePath); + compileOptions.SetOutputModelGetInitializerLocationDelegate(BasicHandleInitializer); compileOptions.CompileModel(); - Assert.Fail("CompileModel() should have thrown an exception"); + Assert.True(File.Exists(outputModelFilePath)); } - catch (OnnxRuntimeException ex) + } + finally + { + if (File.Exists(outputModelFilePath)) { - Assert.Contains("exists already", ex.Message); + // This file is created by ORT, so we delete it manually in finally block. + File.Delete(outputModelFilePath); } + } + } + + [Fact] + public void GetInitializerLocationDelegateThatReusesExternalInitializers() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("conv_qdq_external_ini.onnx"); + var outputModelFilePath = "conv_qdq_external_ini.reuse.ctx.onnx"; + bool reusedExternalInitializers = false; + + try + { + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + // Custom delegate that reuses the original external initializer file. + OrtExternalInitializerInfo ReuseExternalInitializers( + string initializerName, IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation) + { + Assert.True(initializerName.Length > 0); + + if (originalInitializerLocation != null) + { + reusedExternalInitializers = true; // For test assertion only + string originalFilePath = originalInitializerLocation.GetFilePath(); + long originalFileOffset = originalInitializerLocation.GetFileOffset(); + long originalByteSize = originalInitializerLocation.GetByteSize(); + + Assert.True(originalFilePath.Length > 0); + Assert.True(originalFileOffset >= 0); + Assert.True(originalByteSize > 0); - if (File.Exists(output_model_file)) + // This initializer comes from an external file. Reuse it for compiled model. 
+ return new OrtExternalInitializerInfo(originalFilePath, originalFileOffset, originalByteSize); + } + + // Otherwise, embed initializers that were not originally external. + return null; + } + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFilePath); + compileOptions.SetOutputModelGetInitializerLocationDelegate(ReuseExternalInitializers); + compileOptions.CompileModel(); + + Assert.True(File.Exists(outputModelFilePath)); + Assert.True(reusedExternalInitializers); + } + } + finally + { + if (File.Exists(outputModelFilePath)) { - File.Delete(output_model_file); + // This file is created by ORT, so we delete it manually in finally block. + File.Delete(outputModelFilePath); } } } diff --git a/csharp/testdata/conv_qdq_external_ini.bin b/csharp/testdata/conv_qdq_external_ini.bin new file mode 100644 index 0000000000000..89eea0dba1fa4 Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.bin differ diff --git a/csharp/testdata/conv_qdq_external_ini.onnx b/csharp/testdata/conv_qdq_external_ini.onnx new file mode 100644 index 0000000000000..c53e1f3ad4d9b Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.onnx differ diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 866892979b749..9a0708d72b4f8 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1247,6 +1247,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const std::filesystem::path& model_file_path, const ModelSavingOptions& model_saving_options) const; + /// + /// Serialize the Graph to a onnx::GraphProto. Caller provides a function that determines where each initializer + /// is stored (i.e., either in an external file or within the model). + /// + /// Function called for every initializer. + /// Opaque user state passed to the handle_initializer_func. 
+ /// Output parameter set to the serialized onnx::GraphProto. + /// A status indicating success or an error. + common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const; + /** Gets the ISchemaRegistry instances being used with this Graph. */ IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const; @@ -1664,6 +1676,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::ostream& external_stream, int64_t& external_offset) const; + Status ToGraphProtoWithCustomInitializerHandlingImpl(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const; #endif Version IrVersion() const noexcept { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5d0c273a218fe..81caf5069bb6e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -535,6 +535,57 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e _Out_ size_t* num_selected, _In_ void* state); +/** \brief Function called by ORT to write a buffer to a custom destination (e.g., file, stream, etc.). + * + * \param state Opaque pointer holding the user's state. + * \param buffer The buffer to write. + * \param buffer_num_bytes The size of the buffer in bytes. + * + * \return OrtStatus* Write status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. 
+ */ +typedef OrtStatus*(ORT_API_CALL* OrtWriteBufferFunc)(_In_ void* state, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes); + +/** \brief Function called by ORT to allow user to specify how an initializer should be saved, that is, either + * written to an external file or stored within the model. ORT calls this function for every initializer when + * generating a model. + * + * If the function implementation sets the `new_external_info` output parameter to NULL, ORT stores the initializer data + * within the generated model. + * + * Otherwise, if the function implementation sets `new_external_info` to a valid OrtExternalInitializerInfo instance, + * ORT assumes that this function stores the initializer data in a file. In this case, ORT configures the model's + * initializer to point to the location specified by the `new_external_info` output parameter. + * + * \param[in] state Opaque pointer holding the user's state. + * \param[in] initializer_name The initializer's name as a null-terminated string. + * \param[in] initializer_value OrtValue containing the initializer's data, type, and shape. + * \param[in] external_info If the initializer is originally stored in an external file, `external_info` contains + * the file path, file offset, and the data's byte size within the file. Otherwise, + * `external_info` is NULL if the initializer is not originally stored in a file. + * \param[out] new_external_info Output parameter set to a new OrtExternalInitializerInfo instance indicating the + * location where the function implementation stored the initializer data. + * The function implementation must use `OrtApi::CreateExternalInitializerInfo()` to + * create the instance. + * If the function implementation sets `new_external_info` to NULL, + * ORT stores the initializers within the model. + * + * \note ORT takes ownership of the `new_external_info` output parameter. + * + * \return OrtStatus* Write status. Return nullptr on success. 
+ * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtGetInitializerLocationFunc)( + _In_ void* state, + _In_ const char* initializer_name, + _In_ const OrtValue* initializer_value, + _In_opt_ const OrtExternalInitializerInfo* external_info, + _Outptr_result_maybenull_ OrtExternalInitializerInfo** new_external_info); + /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -6509,6 +6560,26 @@ struct OrtApi { _In_ size_t num_ep_devices, _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* out_status); + + /// \name OrtExternalInitializerInfo + /// @{ + + /** \brief Creates an OrtExternalInitializerInfo instance. + * + * \param[in] filepath The relative path to the file that stores the initializer's data. ORT copies this path string. + * \param[in] file_offset The byte offset where the initializer's data is stored within the file. + * \param[in] byte_size The size in bytes of the initializer's data within the file. + * \param[out] out Output parameter set to the new OrtExternalInitializerInfo instance. + * Must be released by calling ReleaseExternalInitializerInfo(). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset, + _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out); + + /// @} }; /* @@ -7267,6 +7338,43 @@ struct OrtCompileApi { ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel, _In_ OrtModelCompilationOptions* model_compile_options, _In_ GraphOptimizationLevel graph_optimization_level); + + /** \brief Sets a OrtWriteBufferFunc function that is called by ORT to write out the output model's serialized + * ONNX bytes. 
+ *
+ * The provided write function may be called repeatedly until the entire output model has been written out. Each call
+ * to the write function is expected to consume the entire input buffer.
+ *
+ * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions
+ * that begin with ModelCompilationOptions_SetOutputModel____.
+ *
+ * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+ * \param[in] write_func The OrtWriteBufferFunc function called by ORT when writing out the model.
+ * \param[in] state Opaque state passed as the first argument to OrtWriteBufferFunc. Can be NULL.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelWriteFunc,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ OrtWriteBufferFunc write_func, _In_ void* state);
+
+ /** \brief Sets an OrtGetInitializerLocationFunc function that is called by ORT for every initializer in the generated
+ * model. Allows implementer to specify whether initializers should be stored within the model or externally.
+ *
+ * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+ * \param[in] get_initializer_location_func The OrtGetInitializerLocationFunc function called by ORT
+ * to determine the location of the initializer.
+ * \param[in] state Opaque state passed as the first argument to OrtGetInitializerLocationFunc. Can be NULL.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index af0f5046a3f9f..4a8c67e2215ec 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -904,6 +904,9 @@ struct ConstExternalInitializerInfoImpl : Base { using ConstExternalInitializerInfo = detail::ConstExternalInitializerInfoImpl>; +/** \brief Wrapper around ::OrtExternalInitializerInfo + * + */ struct ExternalInitializerInfo : detail::ConstExternalInitializerInfoImpl { using Base = detail::ConstExternalInitializerInfoImpl; using Base::Base; @@ -913,6 +916,13 @@ struct ExternalInitializerInfo : detail::ConstExternalInitializerInfoImpl{p} {} ConstExternalInitializerInfo GetConst() const { return ConstExternalInitializerInfo{this->p_}; } + + ///< Wraps OrtApi::CreateExternalInitializerInfo + ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size); + + ///< Wrapper around CreateExternalInitializerInfo that does not throw an exception. 
+ static Status Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size, + /*out*/ ExternalInitializerInfo& out); }; namespace detail { @@ -1454,8 +1464,18 @@ struct ModelCompilationOptions : detail::Base { ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path, size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile + + ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc + ModelCompilationOptions& SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, + void* state); + ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + + ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelWriteFunc + ModelCompilationOptions& SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); + ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 30ff4753f42a8..59979189eed0f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -589,6 +589,24 @@ inline size_t ConstExternalInitializerInfoImpl::GetByteSize() const { } } // namespace detail +inline ExternalInitializerInfo::ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t 
file_offset, + size_t byte_size) { + ThrowOnError(GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &this->p_)); +} + +inline Status ExternalInitializerInfo::Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size, + /*out*/ ExternalInitializerInfo& out) { + OrtExternalInitializerInfo* info = nullptr; + OrtStatus* status = GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &info); + if (status != nullptr) { + return Status{status}; + } + + out = ExternalInitializerInfo(info); + + return Status{nullptr}; +} + namespace detail { template inline const char* KeyValuePairsImpl::GetValue(const char* key) const { @@ -1021,6 +1039,16 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalI return *this; } +inline ModelCompilationOptions& +ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, void* state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + this->p_, + get_initializer_location_func, + state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator, @@ -1029,6 +1057,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, + void* state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelWriteFunc(this->p_, write_func, state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( bool embed_ep_context_in_model) { 
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode( diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 38f034cf6d266..c0fe171f76037 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -33,6 +33,7 @@ OrtCompileApiFlags, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 + OrtExternalInitializerInfo, # noqa: F401 OrtHardwareDevice, # noqa: F401 OrtHardwareDeviceType, # noqa: F401 OrtMemoryInfo, # noqa: F401 diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc new file mode 100644 index 0000000000000..abfd3cf89cecf --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "core/common/common.h" +#include "core/framework/ep_context_options.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +namespace onnxruntime { +namespace epctx { +// class ModelGenOptions + +ModelGenOptions::ModelGenOptions() = default; + +ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) { + enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + + std::string output_model_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (!output_model_path.empty()) { + output_model_location = std::filesystem::path(output_model_path); + } else { + output_model_location = std::monostate{}; + } + + std::string external_initializers_file_path = config_options.GetConfigOrDefault( + kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); + if (!external_initializers_file_path.empty()) { + ExternalInitializerFileInfo ext_info = {}; + ext_info.file_path = external_initializers_file_path; + ext_info.size_threshold = 0; + initializers_location = std::move(ext_info); + 
} + + embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; +} + +bool ModelGenOptions::HasOutputModelLocation() const { + return !std::holds_alternative(output_model_location); +} + +const std::filesystem::path* ModelGenOptions::TryGetOutputModelPath() const { + return std::get_if(&output_model_location); +} + +const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const { + return std::get_if(&output_model_location); +} + +const BufferWriteFuncHolder* ModelGenOptions::TryGetOutputModelWriteFunc() const { + return std::get_if(&output_model_location); +} + +bool ModelGenOptions::AreInitializersEmbeddedInOutputModel() const { + return std::holds_alternative(initializers_location); +} + +const ExternalInitializerFileInfo* ModelGenOptions::TryGetExternalInitializerFileInfo() const { + return std::get_if(&initializers_location); +} + +const InitializerHandler* ModelGenOptions::TryGetInitializerHandler() const { + return std::get_if(&initializers_location); +} + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h new file mode 100644 index 0000000000000..6643516bfb4c3 --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.h @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/framework/allocator.h" +#include "core/framework/config_options.h" + +namespace onnxruntime { +namespace epctx { +/// +/// Holds the buffer that will store the output model and the allocator used to allocate the memory. +/// +struct BufferHolder { + void** buffer_ptr = nullptr; + size_t* buffer_size_ptr = nullptr; + AllocatorPtr buffer_allocator = nullptr; +}; + +/// +/// Holds the opaque stream state and the write function that ORT calls to write out the output model. 
+/// +struct BufferWriteFuncHolder { + OrtWriteBufferFunc write_func = nullptr; + void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. +}; + +/// +/// Holds path and size threshold used to write out initializers to an external file. +/// +struct ExternalInitializerFileInfo { + std::filesystem::path file_path; + size_t size_threshold = 0; +}; + +/// +/// Holds function and state provided by user to handle initializer data (i.e., write to stream or embed in model). +/// +struct InitializerHandler { + OrtGetInitializerLocationFunc handle_initializer_func = nullptr; + void* state = nullptr; +}; + +/// +/// Stores EPContext model generation options. Used in SessionOptions. +/// +struct ModelGenOptions { + // Action to take if the output model does not have compiled (EPContext) nodes. + enum class ActionIfNoCompiledNodes { + // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior + // to maintain compatibility. The explicit compile API does *not* use this action. + kDontGenerateModel = 0, + + // Generate an output model even if it doesn't have compiled nodes. + // The explicit Compile API defaults to this value. + kGenerateModel, + + // Return an error if the model does not have compiled nodes. + // The explicit Compile API can be configured to this value. + kReturnError, + }; + + ModelGenOptions(); + + // Initializes from string key/value pairs in session config options. + explicit ModelGenOptions(const ConfigOptions& config_options); + + bool enable = false; + bool error_if_output_file_exists = true; + bool error_if_no_compiled_nodes = false; + bool embed_ep_context_in_model = false; + ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel; + + std::variant // Function to write the output model to a user's stream. 
+ output_model_location = std::monostate{}; + + std::variant // Custom function called for every initializer to determine location. + initializers_location = std::monostate{}; + + bool HasOutputModelLocation() const; + const std::filesystem::path* TryGetOutputModelPath() const; + const BufferHolder* TryGetOutputModelBuffer() const; + const BufferWriteFuncHolder* TryGetOutputModelWriteFunc() const; + + bool AreInitializersEmbeddedInOutputModel() const; + const ExternalInitializerFileInfo* TryGetExternalInitializerFileInfo() const; + const InitializerHandler* TryGetInitializerHandler() const; +}; + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/ep_context_utils.cc b/onnxruntime/core/framework/ep_context_utils.cc new file mode 100644 index 0000000000000..3f02c54538526 --- /dev/null +++ b/onnxruntime/core/framework/ep_context_utils.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) +#include +#include +#include "core/framework/ep_context_utils.h" +#include "core/framework/error_code_helper.h" +#include "core/graph/model_saving_options.h" + +namespace onnxruntime { +namespace epctx { + +// Serialize an EPContext model into a onnx::ModelProto. +Status EpContextModelToProto(const onnxruntime::Model& ep_context_model, + const std::filesystem::path& validated_model_path, + const epctx::ModelGenOptions& ep_context_gen_options, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { + // Handle case where initializers are stored inline within the ONNX model. + if (ep_context_gen_options.AreInitializersEmbeddedInOutputModel()) { + // if no external ini file specified, set force_embed_external_ini to true to avoid intermediate file creation + // and force all initializers embed into the ONNX file. 
+ ModelSavingOptions model_saving_options{/*size_threshold*/ SIZE_MAX}; + model_saving_options.force_embed_external_ini = true; + + model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(std::filesystem::path{}, + validated_model_path, + model_saving_options); + return Status::OK(); + } + + // Handle case where initializers (with size > threshold) are stored in an external file. + if (const epctx::ExternalInitializerFileInfo* ext_info = ep_context_gen_options.TryGetExternalInitializerFileInfo(); + ext_info != nullptr) { + ModelSavingOptions model_saving_options{ext_info->size_threshold}; + + model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(ext_info->file_path, + validated_model_path, + model_saving_options); + return Status::OK(); + } + + // Handle case where user specified a custom handler function that determines how each initializer is saved. + if (const epctx::InitializerHandler* custom_handler = ep_context_gen_options.TryGetInitializerHandler(); + custom_handler != nullptr) { + ORT_RETURN_IF_ERROR(ep_context_model.ToGraphProtoWithCustomInitializerHandling( + custom_handler->handle_initializer_func, + custom_handler->state, + model_proto)); + return Status::OK(); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected location for initializers while generating ", + validated_model_path); +} + +// +// OutStreamBuf class: +// + +OutStreamBuf::OutStreamBuf(BufferWriteFuncHolder write_func_holder) + : write_func_holder_(write_func_holder), buffer_(65536) { + setp(buffer_.data(), buffer_.data() + buffer_.size()); +} + +OutStreamBuf::~OutStreamBuf() { + sync(); +} + +// Called when the buffer_ is full. Flushes the buffer_ (via sync()) and then writes the overflow character to buffer_. 
+std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { + if (sync() == -1) { + return traits_type::eof(); + } + + if (ch != traits_type::eof()) { + *pptr() = static_cast(ch); + pbump(1); + } + + return ch; +} + +// Flushes the entire buffer_ to the user's write function. +int OutStreamBuf::sync() { + if (!last_status_.IsOK()) { + return -1; + } + + std::ptrdiff_t num_bytes = pptr() - pbase(); + if (num_bytes == 0) { + return 0; + } + + // Can only call pbump() with an int, so can only write at most (2^31 - 1) bytes. + if (num_bytes > std::numeric_limits::max()) { + num_bytes = std::numeric_limits::max(); + } + + char* ptr = pbase(); + + Status status = Status::OK(); + + ORT_TRY { + status = ToStatusAndRelease(write_func_holder_.write_func(write_func_holder_.stream_state, + ptr, num_bytes)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what()); + }); + } + + if (!status.IsOK()) { + last_status_ = std::move(status); + return -1; + } + + pbump(-static_cast(num_bytes)); // Reset internal pointer to point to the beginning of the buffer_ + return 0; +} + +} // namespace epctx +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/ep_context_utils.h b/onnxruntime/core/framework/ep_context_utils.h new file mode 100644 index 0000000000000..b3c76565982ff --- /dev/null +++ b/onnxruntime/core/framework/ep_context_utils.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include + +#include "core/common/status.h" +#include "core/framework/ep_context_options.h" +#include "core/graph/model.h" + +namespace onnxruntime { +namespace epctx { + +/// +/// Serialize an EPContext model into a onnx::ModelProto based on the provided options. +/// +/// The EP Context model to serialize. +/// The path into which to save the model. May be empty if serialized into a +/// buffer or output stream. +/// The model generation options. +/// Output parameter set to the serialized onnx::ModelProto. +/// A status indicating success or an error. +Status EpContextModelToProto(const onnxruntime::Model& ep_context_model, + const std::filesystem::path& validated_model_path, + const epctx::ModelGenOptions& ep_context_gen_options, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto); + +// Class that wraps the user's OrtBufferWriteFunc function to enable use with +// C++'s std::ostream. +// Example: +// BufferWriteFuncHolder write_func_holder{write_func, stream_state}; +// std::unique_ptr out_stream_buf = std::make_unique(write_func_holder); +// std::ostream out_stream(out_stream_buf.get()); +class OutStreamBuf : public std::streambuf { + public: + explicit OutStreamBuf(BufferWriteFuncHolder write_func_holder); + ~OutStreamBuf(); + + const Status& GetStatus() const { + return last_status_; + } + + protected: + int_type overflow(int_type ch) override; + int sync() override; + + private: + BufferWriteFuncHolder write_func_holder_{}; + std::vector buffer_; + Status last_status_{}; +}; + +} // namespace epctx +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 421e5a6db51b7..43caf4766d5c0 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -5,10 +5,12 @@ #include #include +#include #include 
"core/common/inlined_containers.h" #include "core/common/string_utils.h" #include "core/framework/compute_capability.h" +#include "core/framework/ep_context_utils.h" #include "core/framework/execution_providers.h" #include "core/framework/func_kernel.h" #include "core/framework/kernel_lookup.h" @@ -20,9 +22,9 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" -#include "core/graph/model_saving_options.h" -#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/util/protobuf_parsing_utils.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -766,6 +768,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide } // Validate the ep_context_path to make sure it is file path and check whether the file exist already +// TODO: Move function to ep_context_utils.h/cc static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path, const std::filesystem::path& model_path, std::filesystem::path& context_cache_path, @@ -794,9 +797,10 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_ return Status::OK(); } +// TODO: Move function to ep_context_utils.h/cc static Status CreateEpContextModel(const ExecutionProviders& execution_providers, const Graph& graph, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const logging::Logger& logger) { InlinedVector all_ep_context_nodes; for (const auto& ep : execution_providers) { @@ -807,11 +811,11 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers if (all_ep_context_nodes.size() < 1) { auto action_if_no_compiled_nodes = ep_context_gen_options.action_if_no_compiled_nodes; - 
ORT_RETURN_IF(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError, + ORT_RETURN_IF(action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kReturnError, "Unable to compile any nodes. Check that the session EPs support compilation and can execute " "at least one subgraph in the model."); - if (action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kDontGenerateModel) { + if (action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kDontGenerateModel) { LOGS(logger, WARNING) << "Unable to compile any nodes. ONNX Runtime will not generate a compiled model. " "Either the session EPs do not support compilation or the model is already compiled."; // Note: this path is only taken if a model is compiled with the original compilation approach that uses @@ -821,7 +825,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } // Assert so that this is caught in a test in DEBUG builds (in case a new enum value is added) - assert(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel); + assert(action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel); LOGS(logger, INFO) << "Unable to compile any nodes but will still generate an output model. 
" "Either the session EPs do not support compilation or the model is already compiled."; } @@ -835,15 +839,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return std::make_pair(false, static_cast(nullptr)); }; - bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr && - ep_context_gen_options.output_model_buffer_size_ptr != nullptr && - ep_context_gen_options.output_model_buffer_allocator != nullptr; + const epctx::BufferHolder* output_buffer_holder = ep_context_gen_options.TryGetOutputModelBuffer(); + const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc(); + const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); - std::filesystem::path context_cache_path; - if (!saving_to_buffer || !graph.ModelPath().empty()) { - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, + std::filesystem::path valid_output_model_path; + if (output_model_path_ptr != nullptr || !graph.ModelPath().empty()) { + std::filesystem::path output_model_path = (output_model_path_ptr != nullptr) ? 
*output_model_path_ptr + : std::filesystem::path(""); + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(output_model_path, graph.ModelPath(), - context_cache_path, + valid_output_model_path, ep_context_gen_options.error_if_output_file_exists)); } @@ -910,10 +916,11 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } } + ORT_RETURN_IF_ERROR(ep_graph.Resolve()); + // Generate EP compatibility strings for OrtEp types and add to model metadata // At this point, the graph has been populated with all the EPContext nodes { - ORT_RETURN_IF_ERROR(ep_graph.Resolve()); const GraphViewer graph_viewer(ep_graph); for (const auto& ep : execution_providers) { try { @@ -938,39 +945,60 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } } - size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold; - std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path; - bool force_embed_external_ini = false; - if (external_ini_path.empty()) { - // if no external ini file specified, set force_embed_external_ini to true to avoid intermedia file creation - // and force all initializers embed into the Onnx file - ini_size_threshold = SIZE_MAX; - force_embed_external_ini = true; - } - - ModelSavingOptions model_saving_options{ini_size_threshold}; - model_saving_options.force_embed_external_ini = force_embed_external_ini; + ONNX_NAMESPACE::ModelProto model_proto; + ORT_RETURN_IF_ERROR(EpContextModelToProto(ep_context_model, valid_output_model_path, ep_context_gen_options, + /*out*/ model_proto)); - if (saving_to_buffer) { - ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); - // TODO(adrianlizarraga): Investigate if we can make this more memory efficient. - // May be able to use allocator to directly allocate the ModelProto to avoid a copy. 
- ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path, - context_cache_path, - model_saving_options); + if (output_buffer_holder != nullptr) { + // Write output model into a buffer ORT allocates for the user. size_t buffer_size = model_proto.ByteSizeLong(); ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), "Cannot serialize ONNX ModelProto larger than 2GB"); - AllocatorPtr allocator = ep_context_gen_options.output_model_buffer_allocator; + AllocatorPtr allocator = output_buffer_holder->buffer_allocator; IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, buffer_size); model_proto.SerializeToArray(buffer.get(), static_cast(buffer_size)); - *ep_context_gen_options.output_model_buffer_size_ptr = buffer_size; - *ep_context_gen_options.output_model_buffer_ptr = buffer.release(); + *output_buffer_holder->buffer_size_ptr = buffer_size; + *output_buffer_holder->buffer_ptr = buffer.release(); + } else if (output_write_func_holder != nullptr) { + // Write output model to user's output stream. + size_t buffer_size = model_proto.ByteSizeLong(); + ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), + "Cannot serialize ONNX ModelProto larger than 2GB"); + + auto out_stream_buf = std::make_unique(*output_write_func_holder); + std::ostream out_stream(out_stream_buf.get()); + + model_proto.SerializeToOstream(&out_stream); + out_stream.flush(); + ORT_RETURN_IF_ERROR(out_stream_buf->GetStatus()); } else { - ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, context_cache_path, - external_ini_path, model_saving_options)); + // Write output model to a file. 
+ int fd = 0; + Status status = Env::Default().FileOpenWr(valid_output_model_path, fd); + ORT_RETURN_IF_ERROR(status); + + ORT_TRY { + google::protobuf::io::FileOutputStream output(fd); + bool serialize_result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); + if (!serialize_result) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_PROTOBUF, + "Protobuf serialization failed when generating EPContext model ", + valid_output_model_path); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()); + }); + } + if (!status.IsOK()) { + GSL_SUPPRESS(es .84) + ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); + return status; + } + ORT_RETURN_IF_ERROR(Env::Default().FileClose(fd)); } return Status::OK(); @@ -1221,7 +1249,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const ConfigOptions& config_options, const logging::Logger& logger, Mode mode, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. // 1. Execution providers' capabilities are checked one by one. 
@@ -1268,12 +1296,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - if (ep_context_gen_options.enable && ep_context_gen_options.output_model_buffer_ptr == nullptr) { - // Check before EP compile graphs - std::filesystem::path context_cache_path; - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), - context_cache_path, - ep_context_gen_options.error_if_output_file_exists)); + if (ep_context_gen_options.enable) { + if (const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); + output_model_path_ptr != nullptr) { + // Check before EP compile graphs + std::filesystem::path context_cache_path; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(*output_model_path_ptr, graph.ModelPath(), + context_cache_path, + ep_context_gen_options.error_if_output_file_exists)); + } } // We use this only if Resource Aware Partitioning is enabled for any of the EPs diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 6e36d79701fd7..abe46cea58ab2 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -15,7 +15,10 @@ class ExecutionProviders; class KernelRegistryManager; class Model; struct ConfigOptions; -struct EpContextModelGenerationOptions; + +namespace epctx { +struct ModelGenOptions; +} class GraphPartitioner { public: @@ -50,7 +53,7 @@ class GraphPartitioner { const ConfigOptions& config_options, const logging::Logger& logger, Mode mode = Mode::kNormal, - const EpContextModelGenerationOptions& ep_context_gen_options = {}, + const epctx::ModelGenOptions& ep_context_gen_options = {}, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; bool IsLoadCancellationFlagSet() const { diff --git 
a/onnxruntime/core/framework/session_options.cc b/onnxruntime/core/framework/session_options.cc index 231eb47603838..63f928d52d788 100644 --- a/onnxruntime/core/framework/session_options.cc +++ b/onnxruntime/core/framework/session_options.cc @@ -99,20 +99,11 @@ void SessionOptions::AddCustomOpLibraryHandle(PathString library_name, void* lib } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -EpContextModelGenerationOptions::EpContextModelGenerationOptions(const ConfigOptions& config_options) { - enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; - output_model_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); - output_external_initializers_file_path = config_options.GetConfigOrDefault( - kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); - output_external_initializer_size_threshold = 0; - embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; -} - -EpContextModelGenerationOptions SessionOptions::GetEpContextGenerationOptions() const { +epctx::ModelGenOptions SessionOptions::GetEpContextGenerationOptions() const { if (this->has_explicit_ep_context_gen_options) { return this->ep_context_gen_options; } - return EpContextModelGenerationOptions(this->config_options); + return epctx::ModelGenOptions(this->config_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index b75eeb217e7f0..b328fc916f885 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -13,6 +13,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" #include "core/framework/config_options.h" +#include "core/framework/ep_context_options.h" #include "core/framework/ort_value.h" #include "core/session/onnxruntime_c_api.h" #include 
"core/optimizer/graph_transformer_level.h" @@ -70,53 +71,6 @@ struct FreeDimensionOverride { using CheckLoadCancellationFn = std::function; -/// -/// Options that configure the generation of a compiled model (i.e., a model with EPContext nodes). -/// There are two ways to compile a model: -/// 1. By specifying the correct session option configurations and creating an inference session. -/// The compiled model is generated as a side-effect of session creation. -/// 2. Using an explicit compile API (see OrtCompileApi struct in onnxruntime_c_api.h). -/// -/// The default values in this struct are set to match the current/default behavior of approach 1 to maintain -/// compatibility with the older way of compiling. The explicit compile API overrides some of these values to -/// provide its own defaults (see core/session/model_compilation_options.h/cc). -/// -struct EpContextModelGenerationOptions { - // Action to take if the output model does not have compiled (EPContext) nodes. - enum class ActionIfNoCompiledNodes { - // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior - // to maintain compatibility. The explicit compile API does *not* use this action. - kDontGenerateModel = 0, - - // Generate an output model even if it doesn't have compiled nodes. - // The explicit Compile API defaults to this value. - kGenerateModel, - - // Return an error if the model does not have compiled nodes. - // The explicit Compile API can be configured to this value. - kReturnError, - }; - - EpContextModelGenerationOptions() = default; - - // Initializes from string key/value pairs in session config options. - // This initializes this struct from options set via the older compiling approach #1 above. 
- explicit EpContextModelGenerationOptions(const ConfigOptions& config_options); - - bool enable = false; - bool error_if_output_file_exists = true; - ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel; - bool embed_ep_context_in_model = false; - - std::string output_model_file_path; - void** output_model_buffer_ptr = nullptr; - size_t* output_model_buffer_size_ptr = nullptr; - AllocatorPtr output_model_buffer_allocator = nullptr; - - std::string output_external_initializers_file_path; - size_t output_external_initializer_size_threshold = 0; -}; - struct EpSelectionPolicy { // flag to detect that a policy was set by the user. // need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered @@ -270,8 +224,8 @@ struct SessionOptions { // The function GetEpContextGenerationOptions() handles conversion of string key/value pairs to the new // struct type. bool has_explicit_ep_context_gen_options = false; - EpContextModelGenerationOptions ep_context_gen_options = {}; - EpContextModelGenerationOptions GetEpContextGenerationOptions() const; + epctx::ModelGenOptions ep_context_gen_options = {}; + epctx::ModelGenOptions GetEpContextGenerationOptions() const; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index d7f5b23d56c70..dfdb3ba962609 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -18,6 +18,13 @@ using ::google::protobuf::RepeatedPtrField; using ::ONNX_NAMESPACE::StringStringEntryProto; namespace onnxruntime { +ExternalDataInfo::ExternalDataInfo() = default; + +#if !defined(ORT_MINIMAL_BUILD) +ExternalDataInfo::ExternalDataInfo(const PathString& rel_path, OFFSET_TYPE offset, size_t length) + : rel_path_(rel_path), offset_(offset), 
length_(length) {} +#endif + Status ExternalDataInfo::Create(const RepeatedPtrField& input, std::unique_ptr& external_data_info_result) { auto external_data_info = std::make_unique(); diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index 784b3f352a78e..aa9bb32922bd7 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -25,6 +25,12 @@ class ExternalDataInfo { using OFFSET_TYPE = off_t; #endif + ExternalDataInfo(); + +#if !defined(ORT_MINIMAL_BUILD) + ExternalDataInfo(const PathString& rel_path, OFFSET_TYPE offset, size_t length); +#endif + const PathString& GetRelPath() const { return rel_path_; } OFFSET_TYPE GetOffset() const { return offset_; } diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 2ef7c4a9091f3..c5d7d4cc4e68c 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -31,8 +31,10 @@ enum class OrtGraphIrApi { kEpApi, }; -// Alias OrtExternalInitializerInfo to the internal type. -struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo {}; +// Alias OrtExternalInitializerInfo to the internal onnxruntime::ExternalDataInfo type. +struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo { + using onnxruntime::ExternalDataInfo::ExternalDataInfo; // inherit constructors +}; /// /// Public type that represents an ONNX value info. 
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 0a228176175eb..9a97711996343 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -17,6 +17,7 @@ #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" +#include "core/framework/error_code_helper.h" #include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/tensor_external_data_info.h" @@ -4357,14 +4358,23 @@ Status Graph::RegenerateInitializersAndReplaceInMemory(gsl::span& subgraphs) { + for (const auto& node : nodes) { if (node.ContainsSubgraph()) { // Let's find this node in the output_graph_proto // The node name is optional, so we may need to check by the output value name // given that they can only assigned once. - auto hit = std::find_if(output_graph_proto.mutable_node()->begin(), - output_graph_proto.mutable_node()->end(), + auto hit = std::find_if(graph_proto.mutable_node()->begin(), + graph_proto.mutable_node()->end(), [&node](const ONNX_NAMESPACE::NodeProto& proto) { const auto& node_name = node.Name(); if (!node_name.empty()) @@ -4372,7 +4382,7 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr return (proto.output_size() > 0 && proto.output(0) == node.OutputDefs()[0]->Name()); }); - ORT_RETURN_IF_NOT(hit != output_graph_proto.mutable_node()->end(), "Node ", node.Name(), + ORT_RETURN_IF_NOT(hit != graph_proto.mutable_node()->end(), "Node ", node.Name(), " not found in output_graph_proto"); auto& result_node = *hit; for (const auto& e : node.GetAttributeNameToSubgraphMap()) { @@ -4387,12 +4397,28 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr ORT_RETURN_IF_NOT(sub_hit != result_node.mutable_attribute()->end() && utils::HasGraph(*sub_hit), "Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ", node.Name(), 
" while attempting to recurse into it."); - auto& result_subgraph = *sub_hit->mutable_g(); - ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph)); + SubgraphWithMutableProto subgraph_result{sub_hit->mutable_g(), subgraph}; + subgraphs.emplace_back(subgraph_result); } } } + return Status::OK(); +} + +Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const { + // Process subgraphs recursively (bottom-up). + { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); + + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(*subgraph_proto)); + } + } + // Filter in iterators for weights that are present in the name_to_initial_tensor_ map // and preserve the order. This is needed for tests. InlinedVector initializers_to_process; @@ -4444,44 +4470,19 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( // Process initializers in a subgraph, check their size and // write to an external file. This function also saves pre-packed // blobs for the initializer being saved to disk, if the initializer has any pre-packs. - // This function is invoked by ToGraphProtoWithExternalInitiallizers() and processes subgraphs + // This function is invoked by ToGraphProtoWithExternalInitializers() and processes subgraphs // bottom up. - for (const auto& node : Nodes()) { - if (node.ContainsSubgraph()) { - // Let's find this node in the output_graph_proto - // The node name is optional, so we may need to check by the output value name - // given that they can only assigned once. 
- auto hit = std::find_if(output_graph_proto.mutable_node()->begin(), - output_graph_proto.mutable_node()->end(), - [&node](const ONNX_NAMESPACE::NodeProto& proto) { - const auto& node_name = node.Name(); - if (!node_name.empty()) - return proto.name() == node_name; - return (proto.output_size() > 0 && - proto.output(0) == node.OutputDefs()[0]->Name()); - }); - ORT_RETURN_IF_NOT(hit != output_graph_proto.mutable_node()->end(), "Node ", node.Name(), - " not found in output_graph_proto"); - auto& result_node = *hit; - for (const auto& e : node.GetAttributeNameToSubgraphMap()) { - const auto& name = e.first; - const auto& subgraph = e.second; - // Lets find this subgraph in the result_node - auto sub_hit = std::find_if(result_node.mutable_attribute()->begin(), - result_node.mutable_attribute()->end(), - [&name](const ONNX_NAMESPACE::AttributeProto& proto) { - return proto.name() == name; - }); - ORT_RETURN_IF_NOT(sub_hit != result_node.mutable_attribute()->end() && utils::HasGraph(*sub_hit), - "Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ", - node.Name(), " while attempting to recurse into it."); - auto& result_subgraph = *sub_hit->mutable_g(); - ORT_RETURN_IF_ERROR(subgraph->AddExternalInitializersToGraphProtoImpl( - model_path, external_file_path, - model_external_file_path, model_saving_options, - result_subgraph, - external_stream, external_offset)); - } + { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); + + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->AddExternalInitializersToGraphProtoImpl( + model_path, external_file_path, + model_external_file_path, model_saving_options, + *subgraph_proto, external_stream, external_offset)); } } @@ -4643,6 +4644,113 @@ 
ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers( return result; } +Status Graph::ToGraphProtoWithCustomInitializerHandlingImpl( + OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const { + // This loop processes subgraphs bottom up. + { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); + + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->ToGraphProtoWithCustomInitializerHandlingImpl(handle_initializer_func, + state, *subgraph_proto)); + } + } + + // Create a sorted std::vector of initializers so that we always process them in a deterministic order. + InlinedVector initializers; + initializers.reserve(GetAllInitializedTensors().size()); + + for (const auto& [name, initializer_tp] : GetAllInitializedTensors()) { + initializers.push_back(initializer_tp); + } + + std::sort(initializers.begin(), initializers.end(), + [](const ONNX_NAMESPACE::TensorProto* a, const ONNX_NAMESPACE::TensorProto* b) { + return a->name() < b->name(); + }); + + // Call user's handler function for each initializer. We store the initializer externally + // or within the model depending on the result returned by the handler function. + for (gsl::not_null initializer : initializers) { +#if !defined(DISABLE_SPARSE_TENSORS) + if (IsSparseInitializer(initializer->name())) { + // Sparse tensors are added to the ONNX file directly. 
+ auto& sparse_initializer = *output_graph_proto.add_sparse_initializer(); + ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(*initializer, ModelPath(), sparse_initializer)); + } else { +#endif + TensorProto* output_proto = output_graph_proto.add_initializer(); + + output_proto->set_name(initializer->name()); + output_proto->set_data_type(initializer->data_type()); + for (int i = 0; i != initializer->dims_size(); ++i) { + output_proto->add_dims(initializer->dims(i)); + } + output_proto->set_doc_string(initializer->doc_string()); + + OrtValue ort_value; + std::unique_ptr original_ext_data_info = nullptr; + + if (utils::HasExternalDataInFile(*initializer)) { + // Initializer has data in an external file. Load it into OrtValue (potentially via memory mapping). + ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(initializer->external_data(), original_ext_data_info)); + ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(Env::Default(), ModelPath(), *initializer, ort_value)); + } else { + // Initializer is either stored inline within the TensorProto or it is "external data in memory". + // Get an OrtValue (if already loaded by Graph) or copy into an OrtValue otherwise. + bool graph_has_ort_value = GetOrtValueInitializer(initializer->name(), ort_value, /*check_outer_scope*/ false); + if (!graph_has_ort_value) { + assert(!utils::HasExternalData(*initializer)); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), ModelPath(), *initializer, + CPUAllocator::DefaultInstance(), ort_value)); + } + } + + // Call the user's initializer handler function. If the user wants to store the initializer externally, + // the handler function will use OrtApi::CreateExternalInitializerInfo() to create a new + // OrtExternalInitializerInfo instance that indicates the location of the data. 
+ OrtExternalInitializerInfo* new_external_info = nullptr; + Status status = ToStatusAndRelease(handle_initializer_func(state, initializer->name().c_str(), + &ort_value, + static_cast(original_ext_data_info.get()), + &new_external_info)); + + ORT_RETURN_IF(new_external_info != nullptr && + new_external_info == static_cast(original_ext_data_info.get()), + "User's OrtGetInitializerLocationFunc must not return the external_info parameter.", + "Return a copy instead."); + std::unique_ptr new_external_info_holder(new_external_info); // Take ownership + ORT_RETURN_IF_ERROR(status); + + if (new_external_info != nullptr) { + ExternalDataInfo::SetExternalLocationToProto(new_external_info->GetRelPath(), new_external_info->GetOffset(), + new_external_info->GetLength(), *output_proto); + } else { + const Tensor& tensor = ort_value.Get(); + output_proto->clear_data_location(); + utils::SetRawDataInTensorProto(*output_proto, tensor.DataRaw(), tensor.SizeInBytes()); + } +#if !defined(DISABLE_SPARSE_TENSORS) + } +#endif + } + + return Status::OK(); +} + +Status Graph::ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const { + ToGraphProtoInternal(graph_proto); + ORT_RETURN_IF_ERROR(ToGraphProtoWithCustomInitializerHandlingImpl(handle_initializer_func, state, graph_proto)); + return Status::OK(); +} + void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const { graph_proto_->clear_node(); graph_proto_->clear_input(); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index eb5e1e89e2f9c..0ffbced51ee35 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -415,6 +415,25 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa return result; } +common::Status Model::ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + 
void* state, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) const { + model_proto = model_proto_; + + // Sync current model_metadata_ back to protobuf metadata_props + model_proto.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{model_proto.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + + const auto& graph = *graph_; + ORT_RETURN_IF_ERROR(graph.ToGraphProtoWithCustomInitializerHandling(handle_initializer_func, + state, *model_proto.mutable_graph())); + return Status::OK(); +} + Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) { if (!model_istream.good()) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid istream object."); diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index e8722f6f5c0b2..c86aac44806bd 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -210,6 +210,18 @@ class Model { const std::filesystem::path& file_path, const ModelSavingOptions& model_saving_options) const; + /// + /// Serialize the Model to a onnx::ModelProto. Caller provides a function that determines where each initializer + /// is stored (i.e., either in an external file or within the model). + /// + /// Function called for every initializer. + /// Opaque user state passed to the handle_initializer_func. + /// Output parameter set to the serialized onnx::ModelProto. + /// A status indicating success or an error. 
+ common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) const; + static common::Status Save(Model& model, const PathString& file_path); static common::Status Save(Model& model, int fd); diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 759773042debb..b9a54ea7104e1 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -64,7 +64,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModelPath, API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - std::string model_path = PathToUTF8String(input_model_path); + std::filesystem::path model_path = input_model_path; if (model_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: path string is empty"); @@ -113,7 +113,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - std::string model_path = PathToUTF8String(output_model_path); + std::filesystem::path model_path = output_model_path; if (model_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output model path: path is empty"); } @@ -136,17 +136,18 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInf #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - std::string output_dir = PathToUTF8String(output_directory); - if (output_dir.empty()) { + std::filesystem::path output_directory_path = output_directory; + if (output_directory_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty"); } - std::string model_name_str = 
ToUTF8String(model_name); - if (model_name_str.empty()) { + std::filesystem::path model_name_path = model_name; + if (model_name_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str)); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_directory_path, + model_name_path)); return nullptr; #else ORT_UNUSED_PARAMETER(ort_model_compile_options); @@ -163,7 +164,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExterna size_t external_initializer_size_threshold) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) - std::string initializers_file_path = PathToUTF8String(external_initializers_file_path); + std::filesystem::path initializers_file_path = external_initializers_file_path; if (initializers_file_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid external initializer file: path is empty"); } @@ -214,6 +215,50 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ OrtWriteBufferFunc write_func, _In_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (write_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtWriteBufferFunc function for output model is null"); + } + + model_compile_options->SetOutputModelWriteFunc(write_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(write_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // 
!defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (get_initializer_location_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "OrtGetInitializerLocationFunc function for output model is null"); + } + + model_compile_options->SetOutputModelGetInitializerLocationFunc(get_initializer_location_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(get_initializer_location_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* ort_model_compile_options, bool embed_ep_context_in_model) { @@ -295,6 +340,8 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetFlags, &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, &OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 51cf71cd6ec61..34fa06340a7f9 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -35,5 +35,11 @@ 
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel, _In_ OrtModelCompilationOptions* model_compile_options, _In_ GraphOptimizationLevel graph_optimization_level); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelWriteFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtWriteBufferFunc write_func, _In_ void* state); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 3ada35eeaff63..84f41771cb62b 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -7,8 +7,11 @@ #include #include #include +#include +#include "core/common/path_string.h" #include "core/framework/allocator.h" +#include "core/framework/ep_context_options.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -22,7 +25,7 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& // defaulting to kGenerateModel to support wider usage. session_options_.value.ep_context_gen_options.action_if_no_compiled_nodes = - EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; + epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel; // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions. 
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK()); @@ -31,7 +34,7 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& session_options_.value.graph_optimization_level = TransformerLevel::Default; // L0: required transformers only } -void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) { +void ModelCompilationOptions::SetInputModelPath(const std::filesystem::path& input_model_path) { ResetInputModelSettings(); input_model_path_ = input_model_path; } @@ -42,17 +45,16 @@ void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_da input_model_data_size_ = input_model_data_size; } -Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_model_path) { - ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); - +Status ModelCompilationOptions::SetOutputModelPath(const std::filesystem::path& output_model_path) { ConfigOptions& config_options = session_options_.value.config_options; - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + + ep_context_gen_options.output_model_location = output_model_path; - ep_context_gen_options.output_model_file_path = output_model_path; + std::string output_model_path_str = PathToUTF8String(output_model_path); - if (ep_context_gen_options.output_model_file_path.size() <= ConfigOptions::kMaxValueLength) { - Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ep_context_gen_options.output_model_file_path.c_str()); + if (output_model_path_str.size() <= ConfigOptions::kMaxValueLength) { + Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_path_str.c_str()); ORT_ENFORCE(status.IsOK()); // Should not fail because both key/value strings are below the min string 
lengths // required by ConfigOptions::AddConfigEntry(). } else { @@ -73,7 +75,7 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod logging::LoggingManager* log_manager = env_.GetLoggingManager(); if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); - LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() + LOGS(logger, WARNING) << "Output model path length (" << output_model_path_str.size() << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters." << "ORT will still generate the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; @@ -82,40 +84,58 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod return Status::OK(); } -void ModelCompilationOptions::SetOutputModelExternalInitializersFile(const std::string& external_initializers_path, - size_t external_initializer_size_threshold) { - session_options_.value.ep_context_gen_options.output_external_initializers_file_path = external_initializers_path; - session_options_.value.ep_context_gen_options.output_external_initializer_size_threshold = - external_initializer_size_threshold; +void ModelCompilationOptions::SetOutputModelExternalInitializersFile( + const std::filesystem::path& external_initializers_path, + size_t external_initializer_size_threshold) { + session_options_.value.ep_context_gen_options.initializers_location = epctx::ExternalInitializerFileInfo{ + external_initializers_path, + external_initializer_size_threshold, + }; } Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { - ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); + session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferHolder{ + 
output_model_buffer_ptr, + output_model_buffer_size_ptr, + std::move(allocator), + }; - session_options_.value.ep_context_gen_options.output_model_buffer_ptr = output_model_buffer_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_size_ptr = output_model_buffer_size_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_allocator = std::move(allocator); return Status::OK(); } -Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory, - const std::string& model_name) { +void ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state) { + session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferWriteFuncHolder{ + write_func, + state, + }; +} + +void ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, void* state) { + session_options_.value.ep_context_gen_options.initializers_location = epctx::InitializerHandler{ + get_initializer_location_func, + state, + }; +} + +Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::filesystem::path& output_directory, + const std::filesystem::path& model_name) { if (output_directory.empty() || model_name.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty."); } - std::filesystem::path output_dir_path(output_directory); - if (output_dir_path.has_filename() && output_dir_path.extension() == "") { + if (output_directory.has_filename() && output_directory.extension() == "") { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory."); } - std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name); + std::filesystem::path ctx_model_path = output_directory / model_name; + std::string ctx_model_path_str = PathToUTF8String(ctx_model_path); - if (ctx_model_path.string().size() <= 
ConfigOptions::kMaxValueLength) { + if (ctx_model_path_str.size() <= ConfigOptions::kMaxValueLength) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ctx_model_path.string().c_str())); + ctx_model_path_str.c_str())); } else { logging::LoggingManager* log_manager = env_.GetLoggingManager(); if (log_manager != nullptr && log_manager->HasDefaultLogger()) { @@ -138,11 +158,11 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m } Status ModelCompilationOptions::SetFlags(uint32_t flags) { - EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options; + epctx::ModelGenOptions& options = session_options_.value.ep_context_gen_options; options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS; options.action_if_no_compiled_nodes = - (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError - : EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; + (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? 
epctx::ModelGenOptions::ActionIfNoCompiledNodes::kReturnError + : epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel; return Status::OK(); } @@ -154,7 +174,7 @@ bool ModelCompilationOptions::InputModelComesFromFile() const { return !input_model_path_.empty(); } -const std::string& ModelCompilationOptions::GetInputModelPath() const { +const std::filesystem::path& ModelCompilationOptions::GetInputModelPath() const { return input_model_path_; } @@ -200,77 +220,78 @@ Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel return Status::OK(); } -Status ModelCompilationOptions::ResetOutputModelSettings() { - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - ep_context_gen_options.output_model_file_path.clear(); - ep_context_gen_options.output_model_buffer_ptr = nullptr; - ep_context_gen_options.output_model_buffer_size_ptr = nullptr; - ep_context_gen_options.output_model_buffer_allocator = nullptr; - return Status::OK(); -} +Status ModelCompilationOptions::Check() const { + const ConfigOptions& config_options = session_options_.value.config_options; + + ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); + ORT_ENFORCE(config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); -Status ModelCompilationOptions::CheckInputModelSettings() const { - const bool comes_from_file = !input_model_path_.empty(); - const bool comes_from_memory = input_model_data_ != nullptr; + // Check input model settings. 
+ const bool input_from_file = !input_model_path_.empty(); + const bool input_from_memory = input_model_data_ != nullptr; - if (!comes_from_file && !comes_from_memory) { + if (!input_from_file && !input_from_memory) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model to compile must be loaded from either a file or a memory buffer"); } - if (comes_from_file && comes_from_memory) { + if (input_from_file && input_from_memory) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model to compile must be loaded from either a file or a memory buffer, ", "but not both."); } - if (comes_from_file && !std::filesystem::exists(input_model_path_)) { + if (input_from_file && !std::filesystem::exists(input_model_path_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model path does not exist: ", input_model_path_); } - if (comes_from_memory && input_model_data_size_ == 0) { + if (input_from_memory && input_model_data_size_ == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0"); } - return Status::OK(); -} - -Status ModelCompilationOptions::CheckOutputModelSettings() const { - const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + // Check output model settings. + const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + bool has_no_output_model_location = std::holds_alternative( + ep_context_gen_options.output_model_location); - const bool explicit_writes_to_file = !ep_context_gen_options.output_model_file_path.empty(); - const bool writes_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr; - - if (!explicit_writes_to_file && !writes_to_buffer) { - // User did not specify an output file or an output buffer. We default to generating an output file - // with a name based on the input file name, so do not return an error. 
+ if (has_no_output_model_location && input_from_file) { + // User did not specify an output file, an output buffer, or an output write function. We default to generating an + // output file with a name based on the input file name, so do not return an error. return Status::OK(); } - if (explicit_writes_to_file && writes_to_buffer) { + if (has_no_output_model_location) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Output model to compile must be saved either to a file or to a buffer, but not both."); + "Unable to generate an output model path: require an input model path if the location " + "of the output model (e.g., file, buffer, or stream) is not specified."); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_size_ptr == nullptr) { + const epctx::BufferHolder* output_buffer_ptr = ep_context_gen_options.TryGetOutputModelBuffer(); + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_ptr == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid buffer configuration for output model: buffer pointer is null"); + } + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_size_ptr == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: size pointer is null"); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_allocator == nullptr) { + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_allocator == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: allocator is null"); } - return Status::OK(); -} + const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc(); + + if (output_write_func_holder != nullptr && output_write_func_holder->write_func == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid buffer writing function for output model: function pointer is null"); 
+ } -Status ModelCompilationOptions::Check() const { - ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); - ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - ORT_RETURN_IF_ERROR(CheckInputModelSettings()); - ORT_RETURN_IF_ERROR(CheckOutputModelSettings()); return Status::OK(); } + } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index cd9091561af79..45323e6cb13c5 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -4,6 +4,7 @@ #if !defined(ORT_MINIMAL_BUILD) #pragma once +#include #include #include #include "core/common/status.h" @@ -34,7 +35,7 @@ class ModelCompilationOptions { /// Overrides any previous call to SetInputModelPath() or SetInputModelFromBuffer(). /// /// The input model's path - void SetInputModelPath(const std::string& input_model_path); + void SetInputModelPath(const std::filesystem::path& input_model_path); /// /// Sets the buffer that stores the input ONNX model to compile. @@ -50,7 +51,7 @@ class ModelCompilationOptions { /// /// /// Status indicating potential error - Status SetOutputModelPath(const std::string& output_model_path); + Status SetOutputModelPath(const std::filesystem::path& output_model_path); /// /// Sets the file path to the file that will store external ONNX initializers for the compiled model. 
@@ -58,7 +59,7 @@ class ModelCompilationOptions { /// /// Path to the external initializers file to generate /// Initializers that exceed this threshold are external - void SetOutputModelExternalInitializersFile(const std::string& external_initializers_path, + void SetOutputModelExternalInitializersFile(const std::filesystem::path& external_initializers_path, size_t external_initializer_size_threshold); /// @@ -72,6 +73,21 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); + /// + /// Sets an output stream (write function + state) used to write out the compiled model bytes. + /// + /// Write function + /// The user's state + void SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); + + /// + /// Sets a user-provided function to handle serialization of ONNX initializers. + /// + /// The user-provided function called for every initializer + /// The user's state. + void SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, + void* state); + /// /// Sets information relate to EP context binary file. /// EP use this information to decide the location and context binary file name. @@ -80,7 +96,8 @@ class ModelCompilationOptions { /// The folder path to the generated context binary file /// Model name used to decide the context binary file name: [model_name]_[ep].bin /// Status indicating potential error - Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name); + Status SetEpContextBinaryInformation(const std::filesystem::path& output_directory, + const std::filesystem::path& model_name); /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext @@ -107,7 +124,7 @@ class ModelCompilationOptions { /// Returns the file path to the input ONNX model. 
/// /// input model's path - const std::string& GetInputModelPath() const; + const std::filesystem::path& GetInputModelPath() const; /// /// Returns true if the input model is read from a file. @@ -144,13 +161,10 @@ class ModelCompilationOptions { private: void ResetInputModelSettings(); - Status ResetOutputModelSettings(); - Status CheckInputModelSettings() const; - Status CheckOutputModelSettings() const; const onnxruntime::Environment& env_; OrtSessionOptions session_options_; - std::string input_model_path_; + std::filesystem::path input_model_path_; const void* input_model_data_ = nullptr; size_t input_model_data_size_ = 0; }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f3e2a8ce7ba7b..36f7f1f60c36e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2538,6 +2538,23 @@ ORT_API(void, OrtApis::ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExtern delete static_cast(info); } +ORT_API_STATUS_IMPL(OrtApis::CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, + _In_ int64_t file_offset, _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto ext_data_info = std::make_unique(filepath, file_offset, byte_size); + *out = ext_data_info.release(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(filepath); + ORT_UNUSED_PARAMETER(file_offset); + ORT_UNUSED_PARAMETER(byte_size); + ORT_UNUSED_PARAMETER(out); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateExternalInitializerInfo() is not supported in this build."); +#endif + API_IMPL_END +} + ORT_API(const ORTCHAR_T*, OrtApis::ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info) { return info->GetRelPath().c_str(); } @@ -4202,6 +4219,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetModelMetadata, &OrtApis::GetModelCompatibilityForEpDevices, + 
&OrtApis::CreateExternalInitializerInfo, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 6dc4cf9d195cc..78616c7b3973e 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -700,6 +700,8 @@ ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_may // OrtExternalInitializerInfo ORT_API(void, ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExternalInitializerInfo* info); +ORT_API_STATUS_IMPL(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset, + _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out); ORT_API(const ORTCHAR_T*, ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info); ORT_API(int64_t, ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info); ORT_API(size_t, ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 7da7fabb15b15..444027692903c 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -136,13 +136,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op // If ep.context_enable is set, then ep.context_file_path is expected, otherwise ORT don't know where to generate the _ctx.onnx file if (options && model_path == nullptr) { - EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); + epctx::ModelGenOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); // This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case // the user used the older SessionOptions' configuration entries to generate a compiled model. 
- if (ep_ctx_gen_options.enable && - ep_ctx_gen_options.output_model_file_path.empty() && - ep_ctx_gen_options.output_model_buffer_ptr == nullptr) { + if (ep_ctx_gen_options.enable && !ep_ctx_gen_options.HasOutputModelLocation()) { return OrtApis::CreateStatus(ORT_FAIL, "Inference session was configured with EPContext model generation enabled but " "without a valid location (e.g., file or buffer) for the output model. " @@ -383,7 +381,7 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model const OrtSessionOptions* session_options = &model_compile_options.GetSessionOptions(); if (model_compile_options.InputModelComesFromFile()) { - PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath()); + const std::filesystem::path& input_model_path = model_compile_options.GetInputModelPath(); ORT_RETURN_IF_ERROR(ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, input_model_path.c_str(), nullptr, 0, session))); diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 8c8ba214eb714..35abad5760c32 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -9,7 +9,7 @@ import os import typing import warnings -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any from onnxruntime.capi import _pybind_state as C @@ -620,6 +620,36 @@ def _register_ep_custom_ops(self, session_options, providers, provider_options, C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1]) +def make_get_initializer_location_func_wrapper( + get_initializer_location_func: GetInitializerLocationFunc, +) -> GetInitializerLocationWrapperFunc: + """ + Wraps a user's "get initializer location" function. The returned wrapper function adheres to the + signature expected by ORT. 
+ + Need this wrapper to: + - Convert the `initializer_value` parameter from `C.OrtValue` to `onnxruntime.OrtValue`, which is more + convenient for the user's function to use. + - Allow the user's function to return the original `external_info` parameter (this wrapper makes a copy) + """ + + def get_initializer_location_func_wrapper( + initializer_name: str, + initializer_value: C.OrtValue, + external_info: C.OrtExternalInitializerInfo | None, + ) -> C.OrtExternalInitializerInfo | None: + ret_val: C.OrtExternalInitializerInfo | None = get_initializer_location_func( + initializer_name, OrtValue(initializer_value), external_info + ) + if ret_val is not None and ret_val == external_info: + # User returned `external_info` (const and owned by ORT). ORT expects the returned value to be + # a new instance (that it deletes), so make a copy. + ret_val = C.OrtExternalInitializerInfo(ret_val.filepath, ret_val.file_offset, ret_val.byte_size) + return ret_val + + return get_initializer_location_func_wrapper + + class ModelCompiler: """ This class is used to compile an ONNX model. A compiled ONNX model has EPContext nodes that each @@ -648,6 +678,7 @@ def __init__( external_initializers_size_threshold: int = 1024, flags: int = C.OrtCompileApiFlags.NONE, graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL, + get_initializer_location_func: GetInitializerLocationFunc | None = None, ): """ Creates a ModelCompiler instance. @@ -666,6 +697,25 @@ def __init__( flags in onnxruntime.OrtCompileApiFlags. :param graph_optimization_level: The graph optimization level. Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL. + :param get_initializer_location_func: Optional function called for every initializer to allow user to specify + whether an initializer should be stored within the model or externally. 
Example: + ``` + def get_initializer_location( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + byte_size = initializer_value.tensor_size_in_bytes() + + if byte_size < 64: + return None # Store small initializer within compiled model. + + # Else, write initializer to new external file. + value_np = initializer_value.numpy() + file_offset = ext_init_file.tell() + ext_init_file.write(value_np.tobytes()) + return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size) + ``` """ input_model_path: str | os.PathLike | None = None input_model_bytes: bytes | None = None @@ -688,6 +738,18 @@ def __init__( else: external_initializers_file_path = "" + if get_initializer_location_func is not None: + if external_initializers_file_path: + raise ValueError( + "Cannot initialize ModelCompiler with both `external_initializers_file_path` " + "and `get_initializer_location_func`" + ) + self.get_initializer_location_func_wrapper = make_get_initializer_location_func_wrapper( + get_initializer_location_func + ) + else: + self.get_initializer_location_func_wrapper = None + if input_model_path: self._model_compiler = C.ModelCompiler( sess_options, @@ -698,6 +760,7 @@ def __init__( external_initializers_size_threshold, flags, graph_optimization_level, + self.get_initializer_location_func_wrapper, ) else: self._model_compiler = C.ModelCompiler( @@ -709,6 +772,7 @@ def __init__( external_initializers_size_threshold, flags, graph_optimization_level, + self.get_initializer_location_func_wrapper, ) def compile_to_file(self, output_model_path: str | None = None): @@ -738,6 +802,14 @@ def compile_to_bytes(self) -> bytes: """ return self._model_compiler.compile_to_bytes() + def compile_to_stream(self, write_function: Callable[[bytes], None]): + """ + Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write 
function. + Raises an 'InvalidArgument' exception if the compilation options are invalid. + :param write_function: A callable that accepts a bytes buffer to write. + """ + self._model_compiler.compile_to_stream(write_function) + class IOBinding: """ @@ -1298,3 +1370,14 @@ def device_name(self) -> str: Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda """ return self._tensor.device_name().lower() + + +# Type hint for user-specified function that allows the user to specify initializer locations when compiling a model. +GetInitializerLocationFunc = Callable[ + [str, OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None +] + +# Type hint that adheres to the signature expected by ORT. +GetInitializerLocationWrapperFunc = Callable[ + [str, C.OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None +] diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index 69929cb68a775..6ff252b5d1353 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) #include "python/onnxruntime_pybind_model_compiler.h" #include @@ -8,11 +9,56 @@ #include #include "core/common/common.h" #include "core/framework/error_code_helper.h" +#include "core/graph/abi_graph_types.h" #include "core/session/utils.h" namespace onnxruntime { namespace python { +/// +/// This function is called by ORT to allow the user to handle where every initializer is stored +/// (i.e., externally or internally). This function wraps (and calls) the actual Python function +/// provided by the user. 
+/// +/// Opaque state that holds a pointer to the user's Python function. +/// The name of the initializer to handle. +/// The OrtValue with the initializer's data, type, and shape. +/// The original external location of the initializer, if any. May be null. +/// Output parameter set to the initializer's new external location. Function may +/// return NULL if the initializer should be stored within the compiled ONNX model. +/// A status indicating success or an error. +static OrtStatus* ORT_API_CALL PyGetInitializerLocationFuncWrapper( + void* state, + const char* initializer_name, + const OrtValue* initializer_value, + const OrtExternalInitializerInfo* external_info, + /*out*/ OrtExternalInitializerInfo** new_external_info) { + PyGetInitializerLocationFunc* py_func = reinterpret_cast(state); + OrtStatus* status = nullptr; + std::shared_ptr py_new_external_info = nullptr; + + // Call the Python function and convert any exceptions to a status. + ORT_TRY { + py_new_external_info = (*py_func)(initializer_name, *initializer_value, external_info); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + if (py_new_external_info) { + // ORT expects to take ownership of the new external info, so make a copy because other Python code + // may be holding a reference to the `py_new_external_info`. 
+ auto py_result_copy = std::make_unique(*py_new_external_info.get()); + *new_external_info = py_result_copy.release(); + } else { + *new_external_info = nullptr; + } + + return status; +} + onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr& out, onnxruntime::Environment& env, const PySessionOptions& sess_options, @@ -21,8 +67,10 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr(env, sess_options, PrivateConstructorTag{}); + GraphOptimizationLevel graph_optimization_level, + const PyGetInitializerLocationFunc& py_get_initializer_location_func) { + auto model_compiler = std::make_unique(env, sess_options, py_get_initializer_location_func, + PrivateConstructorTag{}); ModelCompilationOptions& compile_options = model_compiler->model_compile_options_; if (input_model_is_path) { @@ -46,6 +94,12 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptrpy_get_initializer_location_func_) { + compile_options.SetOutputModelGetInitializerLocationFunc( + PyGetInitializerLocationFuncWrapper, + reinterpret_cast(&model_compiler->py_get_initializer_location_func_)); + } + out = std::move(model_compiler); return Status::OK(); } @@ -80,9 +134,47 @@ onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer) return Status::OK(); } +/// +/// Function called by ORT to allow the user to write out the compiled ONNX model bytes to a custom output stream. +/// This function wraps (and calls) the actual Python function provided by the user. +/// +/// Opaque state that holds a pointer to the user's Python function. +/// The buffer to write out. Contains a portion of the compiled ONNX model's bytes. +/// The number of bytes in the buffer. +/// A status indicating success or an error. 
+static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, const void* buffer, + size_t buffer_num_bytes) { + PyOutStreamWriteFunc* py_write_func = reinterpret_cast(stream_state); + OrtStatus* status = nullptr; + + // Call the Python write function and convert any exceptions to a status. + ORT_TRY { + pybind11::bytes py_bytes(reinterpret_cast(buffer), buffer_num_bytes); + (*py_write_func)(py_bytes); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + return status; +} + +onnxruntime::Status PyModelCompiler::CompileToOutStream(PyOutStreamWriteFunc& write_func) { + model_compile_options_.SetOutputModelWriteFunc(PyOutStreamWriteFuncWrapper, + reinterpret_cast(&write_func)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(env_, model_compile_options_)); + return Status::OK(); +} + PyModelCompiler::PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options, + const PyGetInitializerLocationFunc& py_get_initializer_location_func, PrivateConstructorTag) - : env_(env), model_compile_options_(env, sess_options) { + : env_(env), + model_compile_options_(env, sess_options), + py_get_initializer_location_func_(py_get_initializer_location_func) { } } // namespace python } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h index d770dbf65cc10..957350accdba2 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.h +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h @@ -3,7 +3,6 @@ // Licensed under the MIT License. 
#pragma once -#if !defined(ORT_MINIMAL_BUILD) #include #include #include "core/common/status.h" @@ -14,11 +13,24 @@ namespace onnxruntime { class Environment; namespace python { +// Type of the function provided by Python code that is called by ORT to write out the compiled model. +using PyOutStreamWriteFunc = std::function; + +// Type of the function provided by Python code that is called by ORT to handle every initializer. +using PyGetInitializerLocationFunc = std::function( + const std::string& initializer_name, + const OrtValue& initializer_value, + const OrtExternalInitializerInfo* external_info)>; + /// /// Class exposed to Python that enables compiling ONNX models. /// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings. /// class PyModelCompiler { +#if defined(ORT_MINIMAL_BUILD) + public: + bool not_defined_in_this_build{}; // Prevent empty class warning. +#else private: // private tag to pass to constructor to ensure that constructor cannot be directly called externally struct PrivateConstructorTag {}; @@ -40,6 +52,7 @@ class PyModelCompiler { /// Flags from OrtCompileApiFlags /// Optimization level for graph transformations on the model. /// Defaults to ORT_DISABLE_ALL to allow EP to get the original loaded model. + /// User's function to handle saving of initializers. /// A Status indicating error or success. static onnxruntime::Status Create(/*out*/ std::unique_ptr& out, onnxruntime::Environment& env, @@ -49,11 +62,13 @@ class PyModelCompiler { const std::string& external_initializers_file_path = {}, size_t external_initializers_size_threshold = 1024, uint32_t flags = 0, - GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL); + GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL, + const PyGetInitializerLocationFunc& py_get_initializer_location_func = nullptr); // Note: Creation should be done via Create(). 
This constructor is public so that it can be called from // std::make_shared(). PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options, + const PyGetInitializerLocationFunc& py_get_initializer_location_func, PrivateConstructorTag); /// @@ -73,11 +88,19 @@ class PyModelCompiler { /// A Status indicating error or success. onnxruntime::Status CompileToBytes(std::string& output_buffer); + /// + /// Compiles the input model and writes the result into the provided output stream (write functor). + /// + /// Write functor that encapsulates the stream's state. + /// A Status indicating error or success. + onnxruntime::Status CompileToOutStream(PyOutStreamWriteFunc& write_func); + private: onnxruntime::Environment& env_; onnxruntime::ModelCompilationOptions model_compile_options_; std::string input_model_bytes_; + PyGetInitializerLocationFunc py_get_initializer_location_func_; +#endif // defined(ORT_MINIMAL_BUILD) }; } // namespace python } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 27c76f7f5c482..e370518b1fffb 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -6,11 +6,8 @@ #include #include "python/onnxruntime_pybind_exceptions.h" #include "python/onnxruntime_pybind_mlvalue.h" -#include "python/onnxruntime_pybind_state_common.h" - -#if !defined(ORT_MINIMAL_BUILD) #include "python/onnxruntime_pybind_model_compiler.h" -#endif // !defined(ORT_MINIMAL_BUILD) +#include "python/onnxruntime_pybind_state_common.h" #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API @@ -45,6 +42,7 @@ #include "core/session/lora_adapters.h" #if !defined(ORT_MINIMAL_BUILD) +#include "core/graph/abi_graph_types.h" #include "core/session/abi_devices.h" #include "core/session/plugin_ep/ep_factory_internal.h" #include 
"core/session/provider_policy_context.h" @@ -2730,6 +2728,35 @@ including arg name, arg type (contains both type and shape).)pbdoc") .value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested) .export_values(); + // Must use a std::shared_ptr to hold OrtExternalInitializerInfo because the same instances is passed + // between C++ and Python (and Python cannot transfer ownership to C++). + py::class_> ort_external_initializer_info_binding( + m, "OrtExternalInitializerInfo", + R"pbdoc(Location information for initializer data stored in an external file)pbdoc"); + ort_external_initializer_info_binding + .def(py::init([](const std::basic_string& filepath, int64_t file_offset, size_t byte_size) { +#if !defined(ORT_MINIMAL_BUILD) + return std::make_shared(filepath, file_offset, byte_size); +#else + ORT_UNUSED_PARAMETER(filepath); + ORT_UNUSED_PARAMETER(file_offset); + ORT_UNUSED_PARAMETER(byte_size); + ORT_THROW("OrtExternalInitializerInfo creation is not supported in this build"); +#endif + })) + .def_property_readonly( + "filepath", + [](OrtExternalInitializerInfo* info) -> std::basic_string { return info->GetRelPath(); }, + R"pbdoc(The relative path to the file in which initializer data is stored.)pbdoc") + .def_property_readonly( + "file_offset", + [](OrtExternalInitializerInfo* info) -> int64_t { return info->GetOffset(); }, + R"pbdoc(The file byte offset where the initializer data is stored.)pbdoc") + .def_property_readonly( + "byte_size", + [](OrtExternalInitializerInfo* info) -> size_t { return info->GetLength(); }, + R"pbdoc(The byte size of the initializer data in the file.)pbdoc"); + py::enum_(m, "OrtCompileApiFlags", py::arithmetic()) .value("NONE", OrtCompileApiFlags_NONE) .value("ERROR_IF_NO_NODES_COMPILED", OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) @@ -2744,7 +2771,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") std::string external_initializers_file_path = {}, size_t external_initializers_size_threshold = 
1024, uint32_t flags = OrtCompileApiFlags_NONE, - GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_DISABLE_ALL) { + GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_DISABLE_ALL, + const PyGetInitializerLocationFunc& py_get_initializer_location_func = nullptr) { #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr result; OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options, @@ -2752,7 +2780,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") embed_compiled_data_into_model, external_initializers_file_path, external_initializers_size_threshold, - flags, graph_optimization_level)); + flags, graph_optimization_level, + py_get_initializer_location_func)); return result; #else ORT_UNUSED_PARAMETER(sess_options); @@ -2763,6 +2792,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_UNUSED_PARAMETER(external_initializers_size_threshold); ORT_UNUSED_PARAMETER(flags); ORT_UNUSED_PARAMETER(graph_optimization_level); + ORT_UNUSED_PARAMETER(py_get_initializer_location_func); ORT_THROW("Compile API is not supported in this build."); #endif })) @@ -2790,7 +2820,19 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_THROW("Compile API is not supported in this build."); #endif }, - R"pbdoc(Compile an ONNX model into a buffer.)pbdoc"); + R"pbdoc(Compile an ONNX model into a buffer.)pbdoc") + .def( + "compile_to_stream", + [](PyModelCompiler* model_compiler, PyOutStreamWriteFunc& py_stream_write_func) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(model_compiler->CompileToOutStream(py_stream_write_func)); +#else + ORT_UNUSED_PARAMETER(model_compiler); + ORT_UNUSED_PARAMETER(py_stream_write_func); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc"); } bool InitArray() { diff --git 
a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 6ad21fa9f5cf5..a9d6273ae2f20 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -11,6 +11,7 @@ #include "core/framework/kernel_registry.h" #include "core/framework/op_kernel.h" #include "core/framework/bfc_arena.h" +#include "core/framework/ep_context_options.h" #include "core/framework/session_state.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" @@ -504,7 +505,7 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, ASSERT_STATUS_OK( partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, sess_options.config_options, default_logger, GraphPartitioner::Mode::kNormal, - EpContextModelGenerationOptions{}, + epctx::ModelGenOptions{}, debug_graph_fn)); verifier_fn(graph); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a42a56492b04a..1c8cc6f78fe63 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -2077,6 +2077,278 @@ TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); } } + +// Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. +static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { + std::ofstream* outfile = reinterpret_cast(stream_state); + outfile->write(reinterpret_cast(buffer), buffer_num_bytes); + return nullptr; // No error +} + +// Implementation of OrtOutStreamWriteFunc that directly returns an OrtStatus indicating an error. 
+static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { + ORT_UNUSED_PARAMETER(stream_state); + ORT_UNUSED_PARAMETER(buffer); + ORT_UNUSED_PARAMETER(buffer_num_bytes); + return Ort::GetApi().CreateStatus(ORT_FAIL, "Error from OrtOutStreamWriteFunc callback"); +} + +// Test using the CompileModel() API with settings: +// - input model comes from a file +// - write output model to custom write stream +TEST_F(QnnHTPBackendTests, CompileApi_InputFile_WriteOutputModelBytes) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_inputfile_writeoutputmodelbytes.onnx"); + std::filesystem::remove(input_model_file); + + // Create a test model and save it to a file. + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + ASSERT_STATUS_OK(test_model.Save(input_model_file)); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + const ORTCHAR_T* output_model_file = ORT_TSTR("compileapi_inputfile_writeoutputmodelbytes_ctx.onnx"); + std::filesystem::remove(output_model_file); + + // Open an output file. Test will incrementally write the output model to file + // via calls to our OrtOutStreamWriteFunc callback. + ASSERT_FALSE(std::filesystem::exists(output_model_file)); + std::ofstream outfile(output_model_file, std::ios::binary); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelWriteFunc(TestWriteToStream, reinterpret_cast(&outfile)); + compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); + + // Compile the model. 
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + outfile.flush(); + outfile.close(); + + // Check that the compiled model has the expected number of EPContext nodes. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + CheckEpContextNodeCounts(output_model_file, 2, 2); +} + +// Tests using an OrtOutStreamFunc function that returns an error. +TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { + // Create a test model (in memory). + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + std::string model_data = test_model.Serialize(); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); + compile_options.SetOutputModelWriteFunc(ReturnStatusFromStream, nullptr); // Set output stream that returns error + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. Expect a specific error status. 
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_EQ(status.GetErrorMessage(), "Error from OrtOutStreamWriteFunc callback"); +} + +struct CustomInitializerHandlerState { + const ORTCHAR_T* external_file_path = nullptr; + std::ofstream* outfile = nullptr; +}; + +static OrtStatus* ORT_API_CALL TestHandleInitializerDataFunc(void* state, + const char* initializer_name, + const OrtValue* c_initializer_value, + const OrtExternalInitializerInfo* /*c_external_info*/, + OrtExternalInitializerInfo** c_new_external_info) { + Ort::Status final_status{nullptr}; + + ORT_TRY { + CustomInitializerHandlerState* custom_state = reinterpret_cast(state); + + if (std::string("constant") == initializer_name) { + // Keep a specific initializer in the model just to test both scenarios. + // A real implementation may check the byte size and keep small initializers in the model. + *c_new_external_info = nullptr; + return nullptr; + } + + // + // Store other initializers in an external file. + // + Ort::ConstValue value{c_initializer_value}; + size_t byte_size = value.GetTensorSizeInBytes(); + int64_t offset = custom_state->outfile->tellp(); + const ORTCHAR_T* location = custom_state->external_file_path; + + custom_state->outfile->write(static_cast(value.GetTensorRawData()), byte_size); + custom_state->outfile->flush(); + + // Provide caller (ORT) with the new external info. 
+ Ort::ExternalInitializerInfo new_external_info{nullptr}; + if (Ort::Status status = Ort::ExternalInitializerInfo::Create(location, offset, byte_size, new_external_info); + !status.IsOK()) { + return status.release(); + } + + *c_new_external_info = new_external_info.release(); + } + ORT_CATCH(const Ort::Exception& ex) { + ORT_HANDLE_EXCEPTION(([&ex, &final_status]() { + final_status = Ort::Status{ex}; + })); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION(([&ex, &final_status]() { + final_status = Ort::Status(ex.what(), ORT_FAIL); + })); + } + + return final_status.release(); +} + +// Test using the CompileModel() API with settings: +// - input model comes from a file +// - write output model to a file +// - Use callback to specify where each initializer is stored (i.e., external file or within model). +TEST_F(QnnHTPBackendTests, CompileApi_InputFile_OutputFile_InitializerHandler) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler_ctx.onnx"); + const ORTCHAR_T* initializer_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler.bin"); + std::filesystem::remove(input_model_file); + std::filesystem::remove(output_model_file); + std::filesystem::remove(initializer_file); + + // Create a test model and save it to a file. + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + ASSERT_STATUS_OK(test_model.Save(input_model_file)); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + // Open a file to store external initializers. ORT will call our handler function for every initializer. 
+ ASSERT_FALSE(std::filesystem::exists(initializer_file));
+ std::ofstream outfile(initializer_file, std::ios::binary);
+ CustomInitializerHandlerState custom_state = {initializer_file, &outfile};
+
+ // Create model compilation options from the session options.
+ Ort::ModelCompilationOptions compile_options(*ort_env, so);
+ compile_options.SetInputModelPath(input_model_file);
+ compile_options.SetOutputModelPath(output_model_file);
+ compile_options.SetOutputModelGetInitializerLocationFunc(TestHandleInitializerDataFunc,
+ reinterpret_cast<void*>(&custom_state));
+ compile_options.SetEpContextEmbedMode(true);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
+
+ // Compile the model.
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+ ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+ outfile.flush();
+ outfile.close();
+
+ ASSERT_TRUE(std::filesystem::exists(initializer_file));
+ ASSERT_TRUE(std::filesystem::exists(output_model_file));
+ CheckEpContextNodeCounts(output_model_file, 2, 2);
+}
+
+static OrtStatus* ORT_API_CALL ReuseExternalInitializers(void* state,
+ const char* /*initializer_name*/,
+ const OrtValue* /*initializer_value*/,
+ const OrtExternalInitializerInfo* external_info,
+ OrtExternalInitializerInfo** new_external_info) {
+ Ort::Status final_status{nullptr};
+
+ ORT_TRY {
+ // If the original initializer was stored in an external file, keep it there (just for testing).
+ if (external_info != nullptr) {
+ Ort::ConstExternalInitializerInfo info(external_info);
+ auto location = info.GetFilePath();
+ int64_t offset = info.GetFileOffset();
+ size_t byte_size = info.GetByteSize();
+
+ Ort::ExternalInitializerInfo new_info(nullptr);
+ Ort::Status status = Ort::ExternalInitializerInfo::Create(location.c_str(), offset, byte_size, new_info);
+ if (!status.IsOK()) {
+ return status.release();
+ }
+
+ *new_external_info = new_info.release();
+
+ // Keep track of number of reused external initializers so that we can assert
+ // that we reused the expected number of initializers.
+ // THIS IS TEST CODE. An application would not do this.
+ size_t* num_reused_ext_initializers = reinterpret_cast<size_t*>(state);
+ *num_reused_ext_initializers += 1;
+
+ return nullptr;
+ }
+
+ // If not originally external, save it within the generated compiled model
+ *new_external_info = nullptr;
+ }
+ ORT_CATCH(const Ort::Exception& ex) {
+ ORT_HANDLE_EXCEPTION(([&ex, &final_status]() {
+ final_status = Ort::Status{ex};
+ }));
+ }
+ ORT_CATCH(const std::exception& ex) {
+ ORT_HANDLE_EXCEPTION(([&ex, &final_status]() {
+ final_status = Ort::Status(ex.what(), ORT_FAIL);
+ }));
+ }
+
+ return final_status.release();
+}
+
+// Test using the CompileModel() API with settings:
+// - input model comes from a file
+// - write output model to a file
+// - Use callback to specify where each initializer is stored. We'll reuse external initializers
+// from original model!
+TEST_F(QnnHTPBackendTests, CompileApi_InitializerHandler_ReuseExternalInitializers) {
+ const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/conv_qdq_external_ini.onnx");
+ const ORTCHAR_T* output_model_file = ORT_TSTR("testdata/conv_qdq_external_ini_reuse_ctx.onnx");
+ std::filesystem::remove(output_model_file);
+
+ size_t num_reused_ext_initializers = 0;
+
+ // Create model compilation options from the session options.
+ Ort::SessionOptions so;
+ Ort::ModelCompilationOptions compile_options(*ort_env, so);
+ compile_options.SetInputModelPath(input_model_file);
+ compile_options.SetOutputModelPath(output_model_file);
+ compile_options.SetOutputModelGetInitializerLocationFunc(ReuseExternalInitializers,
+ reinterpret_cast<void*>(&num_reused_ext_initializers));
+ compile_options.SetEpContextEmbedMode(true);
+
+ // Compile the model.
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+ ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+ ASSERT_TRUE(std::filesystem::exists(output_model_file));
+ std::filesystem::remove(output_model_file);
+
+ ASSERT_EQ(num_reused_ext_initializers, 2); // Reused external conv weight and bias.
+}
+
 #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 }  // namespace test
diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
index ed3cd882d7e00..e46cdb4f98850 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
@@ -275,6 +275,149 @@ def compile_and_get_op_counts(
 )
 self.assertEqual(op_counts_1["Cast"], 8)
 
+ def test_compile_from_file_to_stream(self):
+ """
+ Tests compiling a model (from files) to an output stream using a custom write functor.
+ """
+ provider = None
+ provider_options = dict()
+ if "QNNExecutionProvider" in available_providers:
+ provider = "QNNExecutionProvider"
+ provider_options["backend_type"] = "htp"
+
+ input_model_path = get_name("nhwc_resize_scales_opset18.onnx")
+ output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.stream.onnx")
+
+ with open(output_model_path, "wb") as output_fd:
+ # User's custom write functor. Writes the model to a file.
+ def my_write_func(buffer: bytes): + self.assertGreater(len(buffer), 0) + output_fd.write(buffer) + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_stream(my_write_func) + + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_to_stream_that_raises_exception(self): + """ + Tests compiling a model to an output stream that always raises an exception. + """ + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + + # User's custom write functor that raises an exception. + test_py_error_message = "My Python Error" + + def my_write_func(buffer: bytes): + self.assertGreater(len(buffer), 0) + raise ValueError(test_py_error_message) + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + + # Try to compile and expect ORT to raise a Fail exception that contains our message. + with self.assertRaises(Fail) as context: + model_compiler.compile_to_stream(my_write_func) + self.assertIn(test_py_error_message, str(context.exception)) + + def test_compile_with_basic_initializer_location_func(self): + """ + Tests compiling a model using a custom initializer handler that stores initializers + in an external file. 
+ """ + input_model_path = get_name("conv_qdq_external_ini.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler.onnx") + initializer_file_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler.bin") + + if os.path.exists(output_model_path): + os.remove(output_model_path) + + if os.path.exists(initializer_file_path): + os.remove(initializer_file_path) + + with open(initializer_file_path, "wb") as ext_init_file: + + def store_large_initializer_externally( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + self.assertTrue(initializer_name) # Should have valid name + byte_size = initializer_value.tensor_size_in_bytes() + + if byte_size < 64: + return None # Store small initializer within compiled model. + + # Else, write initializer to new external file. + value_np = initializer_value.numpy() + file_offset = ext_init_file.tell() + ext_init_file.write(value_np.tobytes()) + return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size) + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + get_initializer_location_func=store_large_initializer_externally, + ) + model_compiler.compile_to_file(output_model_path) + + self.assertTrue(os.path.exists(output_model_path)) + self.assertTrue(os.path.exists(initializer_file_path)) + + def test_compile_with_initializer_func_that_reuses(self): + """ + Tests compiling a model using a custom initializer handler that reuses external initializer files. 
+ """ + input_model_path = get_name("conv_qdq_external_ini.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler_reuse.onnx") + + if os.path.exists(output_model_path): + os.remove(output_model_path) + + # Function that reuses external initializer files for the compiled model. + def reuse_external_initializers( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + self.assertTrue(initializer_name) # Should have valid name + self.assertNotEqual(initializer_value.data_ptr(), 0) + self.assertGreater(initializer_value.tensor_size_in_bytes(), 0) + if external_info is not None: + # Original initializer is stored externally. + # Make the initializer in the compiled model use the same external file + return external_info + + return None # Otherwise, make a copy of the initializer and store it within compiled model. + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + get_initializer_location_func=reuse_external_initializers, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + def test_fail_load_uncompiled_model_and_then_compile(self): """ Tests compiling scenario: