diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
index 456097ff9db9a..bde39d9c6e6cc 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
@@ -4,6 +4,7 @@
namespace Microsoft.ML.OnnxRuntime
{
using System;
+ using System.Diagnostics;
using System.Runtime.InteropServices;
///
@@ -22,7 +23,7 @@ public enum OrtCompileApiFlags : uint
/// This class is used to set options for model compilation, and to produce a compiled model using those options.
/// See https://onnxruntime.ai/docs/api/c/ for further details of various options.
///
- public class OrtModelCompilationOptions : SafeHandle
+ public class OrtModelCompilationOptions : IDisposable
{
///
/// Create a new OrtModelCompilationOptions object from SessionOptions.
@@ -31,11 +32,10 @@ public class OrtModelCompilationOptions : SafeHandle
/// to enable graph optimizations.
/// SessionOptions instance to read settings from.
public OrtModelCompilationOptions(SessionOptions sessionOptions)
- : base(IntPtr.Zero, true)
{
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtCreateModelCompilationOptionsFromSessionOptions(
- OrtEnv.Instance().Handle, sessionOptions.Handle, out handle));
+ OrtEnv.Instance().Handle, sessionOptions.Handle, out _handle));
}
///
@@ -43,7 +43,7 @@ public OrtModelCompilationOptions(SessionOptions sessionOptions)
///
public void CompileModel()
{
- NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, handle));
+ NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, _handle));
}
@@ -55,7 +55,7 @@ public void SetInputModelPath(string path)
{
var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path);
NativeApiStatus.VerifySuccess(
- NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(handle, platformPath));
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(_handle, platformPath));
}
///
@@ -67,7 +67,7 @@ public void SetInputModelFromBuffer(byte[] buffer)
{
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelFromBuffer(
- handle, buffer, (UIntPtr)buffer.Length));
+ _handle, buffer, (UIntPtr)buffer.Length));
}
///
@@ -78,7 +78,7 @@ public void SetOutputModelPath(string path)
{
var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path);
NativeApiStatus.VerifySuccess(
- NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(handle, platformPath));
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(_handle, platformPath));
}
@@ -93,7 +93,7 @@ public void SetOutputModelExternalInitializersFile(string filePath, ulong thresh
var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath);
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelExternalInitializersFile(
- handle, platformPath, new UIntPtr(threshold)));
+ _handle, platformPath, new UIntPtr(threshold)));
}
// TODO: In order to use this to create an InferenceSession without copying bytes we need more infrastructure.
@@ -108,7 +108,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator,
{
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelBuffer(
- handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr));
+ _handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr));
}
///
@@ -119,7 +119,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator,
public void SetEpContextEmbedMode(bool embed)
{
NativeApiStatus.VerifySuccess(
- NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed));
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(_handle, embed));
}
///
@@ -129,7 +129,7 @@ public void SetEpContextEmbedMode(bool embed)
public void SetFlags(OrtCompileApiFlags flags)
{
NativeApiStatus.VerifySuccess(
- NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags));
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(_handle, (uint)flags));
}
///
@@ -145,7 +145,7 @@ public void SetEpContextBinaryInformation(string outputDirectory, string modelNa
var platformModelName = NativeOnnxValueHelper.GetPlatformSerializedString(modelName);
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation(
- handle, platformOutputDirectory, platformModelName));
+ _handle, platformOutputDirectory, platformModelName));
}
///
@@ -156,26 +156,352 @@ public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLe
{
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel(
- handle, graphOptimizationLevel));
+ _handle, graphOptimizationLevel));
}
- internal IntPtr Handle => handle;
+ ///
+ /// Delegate to write/save a buffer containing ONNX model bytes to a custom destination. The delegate
+ /// may be called repeatedly until the entire output model has been written out. Each call to the delegate
+ /// is expected to consume the entire buffer.
+ ///
+ /// The buffer to write out.
+ ///
+ public delegate void WriteBufferToDestinationDelegate(ReadOnlySpan buffer);
+
+ ///
+ /// Sets a delegate that is called by ORT to write out the output model's serialized ONNX bytes.
+ /// The provided delegate may be called repeatedly until the entire output model has been written out.
+ /// Each call to the delegate is expected to consume/handle the entire input buffer.
+ ///
+ /// The delegate called by ORT to write out the model.
+ public void SetOutputModelWriteDelegate(WriteBufferToDestinationDelegate writeBufferDelegate)
+ {
+ _writeBufferToDestinationDelegateState?.Dispose();
+ _writeBufferToDestinationDelegateState =
+ new DelegateResources(
+ new WriteBufferToDestinationConnector(writeBufferDelegate),
+ new NativeMethods.DOrtWriteBufferToDestinationDelegate(
+ WriteBufferToDestinationConnector.WriteBufferToDestinationDelegateWrapper));
+
+ IntPtr funcPtr = _writeBufferToDestinationDelegateState.GetFunctionPointerForDelegate();
+
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelWriteFunc(
+ _handle,
+ funcPtr,
+ _writeBufferToDestinationDelegateState.GetConnectorHandleAsPointer()));
+ }
+
+ ///
+ /// Delegate called by ORT for every initializer when generating the compiled model.
+ /// The delegate allows the user to determine whether the initializer should be stored within the compiled
+ /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate
+ /// implementation is responsible for writing the initializer data to a file.
+ ///
+ /// The initializer's name.
+ /// The readonly OrtValue instance containing the data, type, and
+ /// shape of the initializer.
+ /// May be null. If the initializer is originally stored externally,
+ /// this contains the file path, file offset, and data size. Otherwise, this is null.
+ /// A new OrtExternalInitializerInfo indicating the new location of the initializer.
+ /// Returns null if the initializer should be stored within the generated compiled model.
+ /// The return value may be null.
+ ///
+ public delegate OrtExternalInitializerInfo GetInitializerLocationDelegate(
+ string initializerName,
+ IReadOnlyOrtValue initializerValue,
+ IReadOnlyExternalInitializerInfo originalInitializerLocation);
+
+ ///
+ /// Sets a delegate that is called by ORT for every initializer when generating the compiled model.
+ /// The delegate allows the user to determine whether the initializer should be stored within the compiled
+ /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate
+ /// implementation is responsible for writing the initializer data to a file.
+ ///
+ /// The delegate called by ORT for every initializer.
+ public void SetOutputModelGetInitializerLocationDelegate(
+ GetInitializerLocationDelegate getInitializerLocationDelegate)
+ {
+ _getInitializerLocationDelegateState?.Dispose();
+ _getInitializerLocationDelegateState =
+ new DelegateResources(
+ new GetInitializerLocationConnector(getInitializerLocationDelegate),
+ new NativeMethods.DOrtGetInitializerLocationDelegate(
+ GetInitializerLocationConnector.GetInitializerLocationDelegateWrapper));
+
+ IntPtr funcPtr = _getInitializerLocationDelegateState.GetFunctionPointerForDelegate();
+
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc(
+ _handle,
+ funcPtr,
+ _getInitializerLocationDelegateState.GetConnectorHandleAsPointer()));
+ }
+
+ #region Delegate helpers
+ ///
+ /// Class to bridge the C# and native worlds for the "write buffer to destination" delegate
+ ///
+ private class WriteBufferToDestinationConnector
+ {
+ private readonly WriteBufferToDestinationDelegate _userDelegate;
+
+ internal WriteBufferToDestinationConnector(WriteBufferToDestinationDelegate writeBufferDelegate)
+ {
+ _userDelegate = writeBufferDelegate;
+ }
+
+ public static IntPtr WriteBufferToDestinationDelegateWrapper(IntPtr /* void* */ state,
+ IntPtr /* const void* */ buffer,
+ UIntPtr /* size_t */ bufferNumBytes)
+ {
+ try
+ {
+
+ WriteBufferToDestinationConnector connector = (WriteBufferToDestinationConnector)
+ GCHandle.FromIntPtr(state).Target;
+ ReadOnlySpan bufferSpan;
+
+ unsafe
+ {
+ // NOTE: A Span can only view 2GB of data. This is fine because ORT does not write out
+ // chunks that large. However, if we ever need to, the solution is to just write a loop here
+ // that repeatedly calls the delegate with smaller chunks of data.
+ bufferSpan = new ReadOnlySpan(buffer.ToPointer(), checked((int)bufferNumBytes));
+ }
+
+ connector._userDelegate(bufferSpan);
+ }
+ catch (Exception ex)
+ {
+ var error = $"The C# WriteBufferToDestination delegate threw an exception: {ex.Message}";
+ IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail,
+ NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error));
+ return status;
+ }
+
+ return IntPtr.Zero;
+ }
+ }
+
+ ///
+ /// Class to bridge the C# and native worlds for the "get initializer location" delegate
+ ///
+ private class GetInitializerLocationConnector
+ {
+ private readonly GetInitializerLocationDelegate _userDelegate;
+
+ internal GetInitializerLocationConnector(GetInitializerLocationDelegate getInitializerLocationDelegate)
+ {
+ _userDelegate = getInitializerLocationDelegate;
+ }
+
+ public static IntPtr GetInitializerLocationDelegateWrapper(
+ IntPtr /* void* */ state,
+ IntPtr /* const char* */ initializerName,
+ IntPtr /* const OrtValue* */ initializerValue,
+ IntPtr /* const OrtExternalInitializerInfo* */ originalInitializerLocation,
+ out IntPtr /* OrtExternalInitializerInfo** */ newInitializerLocationOutput)
+ {
+ newInitializerLocationOutput = IntPtr.Zero;
+
+ try
+ {
+
+ GetInitializerLocationConnector connector = (GetInitializerLocationConnector)GCHandle.
+ FromIntPtr(state).Target;
+ string utf8InitializerName = NativeOnnxValueHelper.StringFromNativeUtf8(initializerName);
+ IReadOnlyOrtValue readOnlyInitializerValue = new OrtValue(initializerValue, owned: false);
+ IReadOnlyExternalInitializerInfo readOnlyOriginalInitializerLocation = null;
+
+ if (originalInitializerLocation != IntPtr.Zero)
+ {
+ readOnlyOriginalInitializerLocation = new OrtExternalInitializerInfo(
+ originalInitializerLocation, ownsHandle: false);
+ }
+ // Call user's delegate, which may return the new location of the initializer.
+ OrtExternalInitializerInfo newInitializerLocation = connector._userDelegate(
+ utf8InitializerName, readOnlyInitializerValue, readOnlyOriginalInitializerLocation);
+
+ if (newInitializerLocation != null)
+ {
+ // Delegate returned info about a new location for the initializer.
+ // Can't guarantee that the new external info returned by user's delegate is not referenced
+ // by other C# code. ORT expects to own the new external info, so create a copy here and
+ // give it to ORT.
+ string newFilePath = newInitializerLocation.GetFilePath();
+ byte[] newFilePathBytes = NativeOnnxValueHelper.GetPlatformSerializedString(newFilePath);
+
+ IntPtr status = NativeMethods.OrtCreateExternalInitializerInfo(
+ newFilePathBytes,
+ newInitializerLocation.GetFileOffset(),
+ (UIntPtr)newInitializerLocation.GetByteSize(),
+ out newInitializerLocationOutput);
+
+ if (status != IntPtr.Zero)
+ {
+ return status;
+ }
+ }
+ else
+ {
+ // User's delegate did not return a new location for the initializer. ORT will store initializer
+ // within the generated compiled model.
+ newInitializerLocationOutput = IntPtr.Zero;
+ }
+ }
+ catch (Exception ex)
+ {
+ var error = $"The C# GetInitializerLocation delegate threw an exception: {ex.Message}";
+ IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail,
+ NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error));
+ return status;
+ }
+
+ return IntPtr.Zero;
+ }
+ }
///
- /// Indicates whether the native handle is invalid.
+ /// Disposable class that stores resources for a delegate provided by the user.
///
- public override bool IsInvalid => handle == IntPtr.Zero;
+ /// The type of the connector class
+ /// (e.g., WriteBufferToDestinationConnector)
+ /// The type of the native delegate.
+ private class DelegateResources : IDisposable
+ where Connector : class
+ where Delegate : class
+ {
+ public DelegateResources(Connector connector, Delegate @delegate)
+ {
+ _connector = connector;
+ _delegate = @delegate;
+ _connectorHandle = GCHandle.Alloc(_connector);
+ _delegateHandle = GCHandle.Alloc(_delegate);
+ }
+
+ internal IntPtr GetFunctionPointerForDelegate()
+ {
+ return Marshal.GetFunctionPointerForDelegate(_delegate);
+ }
+
+ internal IntPtr GetConnectorHandleAsPointer()
+ {
+ return GCHandle.ToIntPtr(_connectorHandle);
+ }
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ if (disposing)
+ {
+ // Dispose other children disposables. We have none.
+ }
+ if (_connectorHandle.IsAllocated)
+ {
+ _connectorHandle.Free();
+ _connector = null;
+ }
+
+ if (_delegateHandle.IsAllocated)
+ {
+ _delegateHandle.Free();
+ _delegate = null;
+ }
+
+ _disposed = true;
+ }
+
+ ~DelegateResources()
+ {
+ Dispose(false);
+ }
+
+ private Connector _connector = null;
+ private Delegate _delegate = null;
+ private GCHandle _connectorHandle = default;
+ private GCHandle _delegateHandle = default;
+ private bool _disposed = false;
+ }
+ #endregion
+
+ #region IDispose implementation
///
- /// Release the native instance of OrtModelCompilationOptions.
+ /// IDispose implementation.
///
- /// true
- protected override bool ReleaseHandle()
+ public void Dispose()
{
- NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(handle);
- handle = IntPtr.Zero;
- return true;
+ Dispose(true);
+ GC.SuppressFinalize(this);
}
+
+ ///
+ /// IDispose implementation
+ ///
+ /// True if Dispose() has been called by the user-side code. False if
+ /// called by the runtime from inside the finalizer.
+ protected virtual void Dispose(bool disposing)
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ if (disposing)
+ {
+ _writeBufferToDestinationDelegateState?.Dispose();
+ _getInitializerLocationDelegateState?.Dispose();
+ }
+
+ Debug.Assert(_handle != IntPtr.Zero);
+ NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(_handle);
+ _handle = IntPtr.Zero;
+ _disposed = true;
+ }
+
+ ///
+ /// Finalizer that releases the native handle if not already released by Dispose().
+ ///
+ ~OrtModelCompilationOptions()
+ {
+ Dispose(false);
+ }
+ #endregion
+
+ ///
+ /// Handle to the native OrtModelCompilationOptions object.
+ ///
+ private IntPtr _handle;
+
+ ///
+ /// True if this OrtModelCompilationOptions instance has already been disposed.
+ ///
+ private bool _disposed = false;
+
+ ///
+ /// Stores delegate state for the "write buffer to destination" delegate.
+ ///
+ private DelegateResources
+ _writeBufferToDestinationDelegateState = null;
+
+ ///
+ /// Stores delegate state for the "get initializer location" delegate.
+ ///
+ private DelegateResources
+ _getInitializerLocationDelegateState = null;
}
-}
\ No newline at end of file
+}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
index 9d25d96bdaa5a..84020d84c9e73 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
@@ -23,6 +23,8 @@ public struct OrtCompileApi
public IntPtr ModelCompilationOptions_SetFlags;
public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation;
public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel;
+ public IntPtr ModelCompilationOptions_SetOutputModelWriteFunc;
+ public IntPtr ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc;
}
internal class NativeMethods
@@ -118,6 +120,22 @@ public DOrtModelCompilationOptions_SetEpContextBinaryInformation
public DOrtModelCompilationOptions_SetGraphOptimizationLevel
OrtModelCompilationOptions_SetGraphOptimizationLevel;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelWriteFunc(
+ IntPtr /* OrtModelCompilationOptions* */ options,
+ IntPtr /* DOrtWriteBufferDelegate */ writeFunc,
+ IntPtr /* void* */ state);
+ public DOrtModelCompilationOptions_SetOutputModelWriteFunc
+ OrtModelCompilationOptions_SetOutputModelWriteFunc;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc(
+ IntPtr /* OrtModelCompilationOptions* */ options,
+ IntPtr /* DOrtHandleInitializerDataDelegate */ handleInitializerFunc,
+ IntPtr /* void* */ state);
+ public DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc
+ OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc;
+
internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi)
{
@@ -188,6 +206,17 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi
_compileApi.ModelCompilationOptions_SetGraphOptimizationLevel,
typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel));
+ OrtModelCompilationOptions_SetOutputModelWriteFunc =
+ (DOrtModelCompilationOptions_SetOutputModelWriteFunc)Marshal.GetDelegateForFunctionPointer(
+ _compileApi.ModelCompilationOptions_SetOutputModelWriteFunc,
+ typeof(DOrtModelCompilationOptions_SetOutputModelWriteFunc));
+
+ OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc =
+ (DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)Marshal.
+ GetDelegateForFunctionPointer(
+ _compileApi.ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
+ typeof(DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc));
+
}
}
}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 3c92400715740..53880308da261 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -450,6 +450,7 @@ public struct OrtApi
public IntPtr Graph_GetModelMetadata;
public IntPtr GetModelCompatibilityForEpDevices;
+ public IntPtr CreateExternalInitializerInfo;
}
internal static class NativeMethods
@@ -787,9 +788,35 @@ static NativeMethods()
api_.SessionOptionsSetEpSelectionPolicyDelegate,
typeof(DSessionOptionsSetEpSelectionPolicyDelegate));
+ OrtReleaseExternalInitializerInfo =
+ (DOrtReleaseExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer(
+ api_.ReleaseExternalInitializerInfo,
+ typeof(DOrtReleaseExternalInitializerInfo));
+
+ OrtExternalInitializerInfo_GetFilePath =
+ (DOrtExternalInitializerInfo_GetFilePath)Marshal.GetDelegateForFunctionPointer(
+ api_.ExternalInitializerInfo_GetFilePath,
+ typeof(DOrtExternalInitializerInfo_GetFilePath));
+
+ OrtExternalInitializerInfo_GetFileOffset =
+ (DOrtExternalInitializerInfo_GetFileOffset)Marshal.GetDelegateForFunctionPointer(
+ api_.ExternalInitializerInfo_GetFileOffset,
+ typeof(DOrtExternalInitializerInfo_GetFileOffset));
+
+ OrtExternalInitializerInfo_GetByteSize =
+ (DOrtExternalInitializerInfo_GetByteSize)Marshal.GetDelegateForFunctionPointer(
+ api_.ExternalInitializerInfo_GetByteSize,
+ typeof(DOrtExternalInitializerInfo_GetByteSize));
+
OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer(
api_.GetModelCompatibilityForEpDevices,
typeof(DOrtGetModelCompatibilityForEpDevices));
+
+ OrtCreateExternalInitializerInfo =
+ (DOrtCreateExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer(
+ api_.CreateExternalInitializerInfo,
+ typeof(DOrtCreateExternalInitializerInfo));
+
}
internal class NativeLib
@@ -2382,6 +2409,70 @@ out IntPtr lora_adapter
public delegate ref CompileApi.OrtCompileApi DOrtGetCompileApi();
#endif
public static DOrtGetCompileApi OrtGetCompileApi;
+
+ ///
+ /// Delegate called by ORT to write a buffer (ONNX model bytes) to a custom destination (e.g., file or stream).
+ ///
+ /// State that was provided when the delegate was registered.
+ /// The buffer to write.
+ /// The size of the buffer in bytes.
+ /// OrtStatus*
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtWriteBufferToDestinationDelegate(
+ IntPtr /* void* */ state,
+ IntPtr /* const void* */ buffer,
+ UIntPtr /* size_t */ bufferNumBytes
+ );
+
+ ///
+ /// Function called by ORT to allow user to specify how an initializer should be saved while compiling
+ /// a model, that is, either written to an external file or stored within the model. ORT calls this function
+ /// for every initializer.
+ ///
+ /// State that was provided when the delegate was registered.
+ /// The initializer's name.
+ /// The OrtValue containing the initializer's data, type, and shape
+ /// The original initializer's location in an external file, or NULL.
+ /// Output parameter set to a new OrtExternalInitializerInfo instance
+ /// indicating the location where the function implementation stored the initializer data. If the function
+ /// implementation sets `newExternalInfo` to NULL, ORT stores the initializer within the generated model.
+ ///
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtGetInitializerLocationDelegate(
+ IntPtr /* void* */ state,
+ IntPtr /* const char* */ initializerName,
+ IntPtr /* const OrtValue* */ initializerValue,
+ IntPtr /* const OrtExternalInitializerInfo* */ externalInfo,
+ out IntPtr /* OrtExternalInitializerInfo** */ newExternalInfo
+ );
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate void DOrtReleaseExternalInitializerInfo(IntPtr /* OrtExternalInitializerInfo* */ info);
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtCreateExternalInitializerInfo(
+ byte[] /* const ORTCHAR_T* */ filePath,
+ long /* int64_t */ fileOffset,
+ UIntPtr /* size_t */ byteSize,
+ out IntPtr /* OrtExternalInitializerInfo** */ outInfo);
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* const ORTCHAR_T* */ DOrtExternalInitializerInfo_GetFilePath(
+ IntPtr /* const OrtExternalInitializerInfo* */ info);
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate long /* int64_t */ DOrtExternalInitializerInfo_GetFileOffset(
+ IntPtr /* const OrtExternalInitializerInfo* */ info);
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate UIntPtr /* size_t */ DOrtExternalInitializerInfo_GetByteSize(
+ IntPtr /* const OrtExternalInitializerInfo* */ info);
+
+ public static DOrtReleaseExternalInitializerInfo OrtReleaseExternalInitializerInfo;
+ public static DOrtCreateExternalInitializerInfo OrtCreateExternalInitializerInfo;
+ public static DOrtExternalInitializerInfo_GetFilePath OrtExternalInitializerInfo_GetFilePath;
+ public static DOrtExternalInitializerInfo_GetFileOffset OrtExternalInitializerInfo_GetFileOffset;
+ public static DOrtExternalInitializerInfo_GetByteSize OrtExternalInitializerInfo_GetByteSize;
#endregion
#region Auto EP API related
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
index fc14be00ee47b..4611428ea12ef 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
@@ -150,6 +150,45 @@ internal static byte[] GetPlatformSerializedString(string str)
else
return StringToZeroTerminatedUtf8(str);
}
+
+ ///
+ /// Converts a null-terminated path string that is pointed to by the given IntPtr handle into
+ /// a C# UTF-16 string.
+ ///
+ /// A path string on Windows is utf-16, but utf-8 on other operating systems.
+ ///
+ ///
+ internal static string StringFromNativePathString(IntPtr strPtr)
+ {
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ if (strPtr == IntPtr.Zero)
+ {
+ return string.Empty;
+ }
+
+ // Get length of utf16 string by checking for two 0 bytes in a row.
+ int length = 0;
+ while (Marshal.ReadInt16(strPtr, length * 2) != 0)
+ {
+ length += 1;
+ }
+
+ if (length == 0)
+ {
+ return string.Empty;
+ }
+
+ unsafe
+ {
+ return System.Text.Encoding.Unicode.GetString((byte*)strPtr, length * 2);
+ }
+ }
+ else
+ {
+ return StringFromNativeUtf8(strPtr);
+ }
+ }
}
// Guards an array of disposable objects on stack and disposes them in reverse order
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs
new file mode 100644
index 0000000000000..aca16e939ce21
--- /dev/null
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs
@@ -0,0 +1,136 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+
+namespace Microsoft.ML.OnnxRuntime
+{
+ using System;
+ using System.Diagnostics;
+ using System.Runtime.InteropServices;
+
+ ///
+ /// Class that stores information about the file location where an "external" initializer is stored.
+ ///
+ ///
+ public class OrtExternalInitializerInfo : SafeHandle, IReadOnlyExternalInitializerInfo
+ {
+ // Set to false when constructed with an externally managed constant handle owned by ORT.
+ private readonly bool _ownsHandle = true;
+
+ ///
+ /// Create a new OrtExternalInitializerInfo instance.
+ ///
+ /// The path to the file that stores the initializer data.
+ /// The byte offset in the file where the data is stored.
+ /// The size of the data (in bytes) within the file.
+ public OrtExternalInitializerInfo(string filePath, long fileOffset, long byteSize)
+ : base(IntPtr.Zero, ownsHandle: true)
+ {
+ var platformFilePath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath);
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.OrtCreateExternalInitializerInfo(platformFilePath, fileOffset, (UIntPtr)byteSize, out handle));
+ _ownsHandle = true;
+ }
+
+ ///
+ /// Create a new OrtExternalInitializerInfo instance from an existing native OrtExternalInitializerInfo handle.
+ ///
+ /// Native OrtExternalInitializerInfo handle.
+ /// True if the OrtExternalInitializerInfo instance owns the native handle.
+ /// Defaults to false.
+ internal OrtExternalInitializerInfo(IntPtr constHandle, bool ownsHandle = false)
+ : base(IntPtr.Zero, ownsHandle)
+ {
+ Debug.Assert(constHandle != IntPtr.Zero);
+ SetHandle(constHandle);
+ _ownsHandle = ownsHandle;
+ }
+
+ ///
+ /// Get the file path to the file that stores the initializer's data.
+ ///
+ ///
+ /// The path is relative to the filesystem directory where the ONNX model was stored.
+ ///
+ /// The file path.
+ public string GetFilePath()
+ {
+ IntPtr filePathPtr = NativeMethods.OrtExternalInitializerInfo_GetFilePath(handle);
+ if (filePathPtr == IntPtr.Zero)
+ {
+ return string.Empty;
+ }
+
+ return NativeOnnxValueHelper.StringFromNativePathString(filePathPtr);
+ }
+
+ ///
+ /// Get the byte offset within the file where the initializer's data is stored.
+ ///
+ /// The file offset location.
+ public long GetFileOffset()
+ {
+ return NativeMethods.OrtExternalInitializerInfo_GetFileOffset(handle);
+ }
+
+ ///
+ /// Get the size in bytes of the initializer's data within the file.
+ ///
+ /// The size in bytes of the initializer data.
+ public long GetByteSize()
+ {
+ UIntPtr byteSize = NativeMethods.OrtExternalInitializerInfo_GetByteSize(handle);
+ return checked((long)byteSize);
+ }
+
+ ///
+ /// Indicates whether the native handle is invalid.
+ ///
+ public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
+
+ ///
+ /// Release the native instance of OrtExternalInitializerInfo if we own it.
+ ///
+ /// true on success and false on error.
+ protected override bool ReleaseHandle()
+ {
+ if (!_ownsHandle)
+ {
+ // Return false to indicate an error.
+ // ReleaseHandle() should not be called on a const handle that this class does not own.
+ return false;
+ }
+
+ NativeMethods.OrtReleaseExternalInitializerInfo(handle);
+ handle = IntPtr.Zero;
+ return true;
+ }
+ }
+
+ ///
+ /// Interface for all readonly methods implemented by OrtExternalInitializerInfo.
+ ///
+ public interface IReadOnlyExternalInitializerInfo
+ {
+ ///
+ /// Get the file path to the file that stores the initializer's data.
+ ///
+ ///
+ /// The path is relative to the filesystem directory where the ONNX model was stored.
+ ///
+ /// The file path.
+ string GetFilePath();
+
+ ///
+ /// Get the byte offset within the file where the initializer's data is stored.
+ ///
+ /// The file offset location.
+ long GetFileOffset();
+
+ ///
+ /// Get the size in bytes of the initializer's data within the file.
+ ///
+ /// The size in bytes of the initializer data.
+ long GetByteSize();
+ }
+}
\ No newline at end of file
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
index 01ee3aa5ae753..d848c63450ec1 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
@@ -33,6 +33,147 @@ public enum OnnxValueType
ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKNOWN)
}
+ ///
+ /// Interface for all readonly methods implemented by OrtValue.
+ ///
+ public interface IReadOnlyOrtValue
+ {
+ ///
+ /// Get the ONNX value type for the OrtValue (e.g., OnnxValueType.ONNX_TYPE_TENSOR).
+ ///
+ /// OnnxValueType
+ OnnxValueType OnnxType { get; }
+
+ ///
+ /// Returns true if OrtValue contains a tensor
+ ///
+ /// true if tensor
+ bool IsTensor { get; }
+
+ ///
+ /// Returns true if OrtValue contains a sparse tensor
+ ///
+ /// true if sparse tensor
+ bool IsSparseTensor { get; }
+
+ ///
+ /// Returns type information about the contained OnnxValue.
+ ///
+ /// a disposable instance of OrtTypeInfo
+ OrtTypeInfo GetTypeInfo();
+
+ ///
+ /// Obtains Tensor And Type Information from the OrtValue iff it contains a tensor.
+ /// Valid only for OrtValues that contain a tensor.
+ ///
+ /// A disposable instance of OrtTensorTypeAndShapeInfo
+ OrtTensorTypeAndShapeInfo GetTensorTypeAndShape();
+
+ ///
+ /// Returns the size of the tensor data in bytes.
+ ///
+ /// size of the tensor data in bytes
+ long GetTensorSizeInBytes();
+
+ ///
+ /// Returns OrtMemoryInfo iff this OrtValue contains a tensor or a sparse tensor.
+ ///
+ /// OrtMemoryInfo that describes the underlying memory allocation
+ ///
+ OrtMemoryInfo GetTensorMemoryInfo();
+
+ ///
+ /// Returns a ReadOnlySpan over tensor native buffer that
+ /// provides a read-only view.
+ ///
+ /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
+ /// To get memory descriptor use GetTensorMemoryInfo().
+ ///
+ /// OrtValue must contain a non-string tensor.
+ /// The span is valid as long as the OrtValue instance is alive (not disposed).
+ ///
+ ///
+ /// ReadOnlySpan
+ ///
+ ReadOnlySpan GetTensorDataAsSpan() where T : unmanaged;
+
+#if NET8_0_OR_GREATER
+ ///
+ /// Returns a ReadOnlyTensorSpan over tensor native buffer that
+ /// provides a read-only view.
+ ///
+ /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
+ /// To get memory descriptor use GetTensorMemoryInfo().
+ ///
+ /// OrtValue must contain a non-string tensor.
+ /// The span is valid as long as the OrtValue instance is alive (not disposed).
+ ///
+ ///
+ /// ReadOnlyTensorSpan
+ ///
+ [Experimental("SYSLIB5001")]
+ SystemNumericsTensors.ReadOnlyTensorSpan GetTensorDataAsTensorSpan() where T : unmanaged;
+#endif
+
+ ///
+ /// Valid for composite ML types like map, sequence.
+ /// Returns 2 for map (keys, values) and N for sequence, where N is the number of elements
+ /// in the sequence.
+ ///
+ /// Element count
+ int GetValueCount();
+
+ ///
+ /// For non tensors return OrtValue element at the specified index.
+ /// For maps only indices 0 and 1 are valid. For sequences, [0..N) are valid.
+ /// See GetValueCount() to determine the valid range.
+ ///
+ ///
+ /// allocator to use
+ /// OrtValue disposable instance that points to the corresponding element of the composite type
+ OrtValue GetValue(int index, OrtAllocator allocator);
+
+ ///
+ /// Fetch string tensor element buffer pointer at the specified index,
+ /// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance.
+ ///
+ /// Obtain TensorTypeAndShape to get shape and element count.
+ ///
+ /// flat string tensor element index
+ /// ReadOnlyMemory{char} backed by a managed char[]. Its lifespan is not
+ /// tied to the native buffer of OrtValue.
+ ReadOnlyMemory GetStringElementAsMemory(int index);
+
+ ///
+ /// Fetch string tensor element buffer pointer at the specified index,
+ /// copy/convert UTF-8 into a UTF-16 string and return it.
+ ///
+ /// Obtain TensorTypeAndShape to get shape and element count.
+ ///
+ /// flat string tensor element index
+ /// UTF-16 string instance
+ string GetStringElement(int index);
+
+ ///
+ /// Get a span over the native memory of the string tensor element.
+ /// The span is valid as long as the OrtValue is valid.
+ ///
+ /// This is useful if you want to perform your own UTF-8 decoding or
+ /// you do not care about decoding.
+ /// Obtain TensorTypeAndShape to get shape and element count.
+ ///
+ /// flat element index
+ /// ReadOnlySpan over UTF-8 bytes of the string tensor element
+ ReadOnlySpan GetStringElementAsSpan(int index);
+
+ ///
+ /// Convenience method to obtain all string tensor elements as a string array.
+ ///
+ /// string[]
+ ///
+ string[] GetStringTensorAsArray();
+ }
+
///
/// Represents a disposable OrtValue.
/// This class exposes a native instance of OrtValue.
@@ -44,7 +185,7 @@ public enum OnnxValueType
/// disposed properly, the pinned memory will continue to be pinned and interfere
/// with GC operation.
///
- public class OrtValue : IOrtValueOwner, IDisposable
+ public class OrtValue : IOrtValueOwner, IDisposable, IReadOnlyOrtValue
{
// OrtValues that are members of Sequences or Maps that map. They potentially map managed memory and we need to keep them around.
// this exists only when we deal with compose ML types.
@@ -52,11 +193,20 @@ public class OrtValue : IOrtValueOwner, IDisposable
private IntPtr _handle;
private MemoryHandle? _memHandle; // Present when the OrtValue is created on top of managed memory
private bool _disposed;
+ private bool _owned = true;
- internal OrtValue(IntPtr handle)
+ ///
+ /// Constructs OrtValue from a native handle. If `owned` is true, the OrtValue instance takes
+ /// ownership of the native handle and disposes it when the OrtValue instance is disposed.
+ ///
+ /// The native OrtValue handle.
+ /// True if this class instance owns the handle. If false, the handle
+ /// will not be released. Defaults to true.
+ internal OrtValue(IntPtr handle, bool owned = true)
{
_handle = handle;
InitOnnxType();
+ _owned = owned;
}
///
@@ -1464,7 +1614,10 @@ protected virtual void Dispose(bool disposing)
}
Debug.Assert(_handle != IntPtr.Zero);
- NativeMethods.OrtReleaseValue(_handle);
+ if (_owned)
+ {
+ NativeMethods.OrtReleaseValue(_handle);
+ }
_handle = IntPtr.Zero;
_disposed = true;
}
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
index f1eef57e03ea5..fe2cab57658c8 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
@@ -21,104 +21,249 @@ public class CompileApiTests
[Fact]
public void BasicUsage()
{
- var so = new SessionOptions();
- using (var compileOptions = new OrtModelCompilationOptions(so))
+ using (var sessionOptions = new SessionOptions())
{
- // mainly checking these don't throw which ensures all the plumbing for the binding works.
- compileOptions.SetInputModelPath("model.onnx");
- compileOptions.SetOutputModelPath("compiled_model.onnx");
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ // mainly checking these don't throw which ensures all the plumbing for the binding works.
+ compileOptions.SetInputModelPath("model.onnx");
+ compileOptions.SetOutputModelPath("compiled_model.onnx");
- compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512);
- compileOptions.SetEpContextEmbedMode(true);
- compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC);
+ compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512);
+ compileOptions.SetEpContextEmbedMode(true);
+ compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC);
- }
+ }
- // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer
- using (var compileOptions = new OrtModelCompilationOptions(so))
- {
- var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
- compileOptions.SetInputModelFromBuffer(model);
+ // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+ compileOptions.SetInputModelFromBuffer(model);
- // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile.
- // Due to that we need to allocate an IntPtr and UIntPtr here.
- IntPtr bytePtr = new IntPtr();
- UIntPtr bytesSize = new UIntPtr();
- var allocator = OrtAllocator.DefaultInstance;
- compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize);
- compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx");
+ // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile.
+ // Due to that we need to allocate an IntPtr and UIntPtr here.
+ IntPtr bytePtr = new IntPtr();
+ UIntPtr bytesSize = new UIntPtr();
+ var allocator = OrtAllocator.DefaultInstance;
+ compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize);
+ compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx");
- compileOptions.CompileModel();
+ compileOptions.CompileModel();
- Assert.NotEqual(IntPtr.Zero, bytePtr);
- Assert.NotEqual(UIntPtr.Zero, bytesSize);
+ Assert.NotEqual(IntPtr.Zero, bytePtr);
+ Assert.NotEqual(UIntPtr.Zero, bytesSize);
- byte[] compiledBytes = new byte[bytesSize.ToUInt64()];
- Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32());
+ byte[] compiledBytes = new byte[bytesSize.ToUInt64()];
+ Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32());
- // Check the compiled model is valid
- using (var session = new InferenceSession(compiledBytes, so))
- {
- Assert.NotNull(session);
+ // Check the compiled model is valid
+ using (var session = new InferenceSession(compiledBytes, sessionOptions))
+ {
+ Assert.NotNull(session);
+ }
+
+ allocator.FreeMemory(bytePtr);
}
- allocator.FreeMemory(bytePtr);
- }
+ // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate
+ // any compiled EPContext nodes, so expect an ORT_FAIL error.
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+ var output_model_file = "should_not_generate.onnx";
+ compileOptions.SetInputModelFromBuffer(model);
+ compileOptions.SetOutputModelPath(output_model_file);
+ compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED);
- // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate
- // any compiled EPContext nodes, so expect an ORT_FAIL error.
- using (var compileOptions = new OrtModelCompilationOptions(so))
- {
- var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
- var output_model_file = "should_not_generate.onnx";
- compileOptions.SetInputModelFromBuffer(model);
- compileOptions.SetOutputModelPath(output_model_file);
- compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED);
+ // compile should fail
+ try
+ {
+ compileOptions.CompileModel();
+ Assert.Fail("CompileModel() should have thrown an exception");
+ }
+ catch (OnnxRuntimeException ex)
+ {
+ Assert.Contains("Unable to compile any nodes", ex.Message);
+ }
- // compile should fail
+ Assert.False(File.Exists(output_model_file)); // Output file should not be generated.
+ }
+
+ // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS.
+ var outputModelFile = "squeezenet_ctx.onnx";
try
{
- compileOptions.CompileModel();
- Assert.Fail("CompileModel() should have thrown an exception");
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+
+ // Compile and generate an output model.
+ compileOptions.SetInputModelFromBuffer(model);
+ compileOptions.SetOutputModelPath(outputModelFile);
+ compileOptions.CompileModel();
+ Assert.True(File.Exists(outputModelFile));
+
+ // Try to compile again with flag that prevents replacing an existing file.
+ // Expect failure.
+ compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS);
+
+ // compile should fail
+ try
+ {
+ compileOptions.CompileModel();
+ Assert.Fail("CompileModel() should have thrown an exception");
+ }
+ catch (OnnxRuntimeException ex)
+ {
+ Assert.Contains("exists already", ex.Message);
+ }
+ }
}
- catch (OnnxRuntimeException ex)
+ finally
{
- Assert.Contains("Unable to compile any nodes", ex.Message);
+ if (File.Exists(outputModelFile))
+ {
+ // This file is created by ORT, so we delete it manually in finally block.
+ File.Delete(outputModelFile);
+ }
}
-
- Assert.False(File.Exists(output_model_file)); // Output file should not be generated.
}
+ }
+
+ [Fact]
+ public void WriteOutModelWithDelegate()
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+ var outputModelFilePath = "squeezenet_write_delegate_ctx.onnx";
- // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS.
- using (var compileOptions = new OrtModelCompilationOptions(so))
+ using (FileStream fs = new FileStream(outputModelFilePath, FileMode.Create, FileAccess.Write, FileShare.None,
+ 4096, FileOptions.DeleteOnClose))
+ using (var sessionOptions = new SessionOptions())
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
{
- var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
- var output_model_file = "squeezenet_ctx.onnx";
+ void BasicWriteBufferDelegate(ReadOnlySpan buffer)
+ {
+ Assert.True(buffer.Length > 0);
+ fs.Write(buffer.ToArray(), 0, buffer.Length); // Write it out to a file
+ }
// Compile and generate an output model.
compileOptions.SetInputModelFromBuffer(model);
- compileOptions.SetOutputModelPath(output_model_file);
+ compileOptions.SetOutputModelWriteDelegate(BasicWriteBufferDelegate);
compileOptions.CompileModel();
- Assert.True(File.Exists(output_model_file));
+ Assert.True(File.Exists(outputModelFilePath));
+ }
+ }
- // Try to compile again with flag that prevents replacing an existing file.
- // Expect failure.
- compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS);
+ [Fact]
+ public void BasicGetInitializerLocationDelegate()
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+ var outputModelFilePath = "squeezenet_handle_initializer_delegate_ctx.onnx";
+ var initializersFilePath = "squeezenet_handle_initializer_delegate_ctx.bin";
- // compile should fail
- try
+ try
+ {
+ using (FileStream fs = new FileStream(initializersFilePath, FileMode.Create, FileAccess.Write,
+ FileShare.None, 4096, FileOptions.DeleteOnClose))
+ using (var sessionOptions = new SessionOptions())
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
{
+ // Custom delegate that stores large initializers in a new file.
+ OrtExternalInitializerInfo BasicHandleInitializer(
+ string initializerName, IReadOnlyOrtValue initializerValue,
+ IReadOnlyExternalInitializerInfo originalInitializerLocation)
+ {
+ Assert.True(initializerName.Length > 0);
+
+ var byteSize = initializerValue.GetTensorSizeInBytes();
+ if (byteSize <= 64)
+ {
+ // Keep small initializers stored within model.
+ return null;
+ }
+
+ long byteOffset = fs.Position;
+ ReadOnlySpan dataSpan = initializerValue.GetTensorDataAsSpan();
+ fs.Write(dataSpan.ToArray(), 0, dataSpan.Length); // Write it out to a file
+
+ // Return the data's new location.
+ return new OrtExternalInitializerInfo(initializersFilePath, byteOffset, byteSize);
+ }
+
+ // Compile and generate an output model.
+ compileOptions.SetInputModelFromBuffer(model);
+ compileOptions.SetOutputModelPath(outputModelFilePath);
+ compileOptions.SetOutputModelGetInitializerLocationDelegate(BasicHandleInitializer);
compileOptions.CompileModel();
- Assert.Fail("CompileModel() should have thrown an exception");
+ Assert.True(File.Exists(outputModelFilePath));
}
- catch (OnnxRuntimeException ex)
+ }
+ finally
+ {
+ if (File.Exists(outputModelFilePath))
{
- Assert.Contains("exists already", ex.Message);
+ // This file is created by ORT, so we delete it manually in finally block.
+ File.Delete(outputModelFilePath);
}
+ }
+ }
+
+ [Fact]
+ public void GetInitializerLocationDelegateThatReusesExternalInitializers()
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("conv_qdq_external_ini.onnx");
+ var outputModelFilePath = "conv_qdq_external_ini.reuse.ctx.onnx";
+ bool reusedExternalInitializers = false;
+
+ try
+ {
+ using (var sessionOptions = new SessionOptions())
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ // Custom delegate that reuses the original external initializer file.
+ OrtExternalInitializerInfo ReuseExternalInitializers(
+ string initializerName, IReadOnlyOrtValue initializerValue,
+ IReadOnlyExternalInitializerInfo originalInitializerLocation)
+ {
+ Assert.True(initializerName.Length > 0);
+
+ if (originalInitializerLocation != null)
+ {
+ reusedExternalInitializers = true; // For test assertion only
+ string originalFilePath = originalInitializerLocation.GetFilePath();
+ long originalFileOffset = originalInitializerLocation.GetFileOffset();
+ long originalByteSize = originalInitializerLocation.GetByteSize();
+
+ Assert.True(originalFilePath.Length > 0);
+ Assert.True(originalFileOffset >= 0);
+ Assert.True(originalByteSize > 0);
- if (File.Exists(output_model_file))
+ // This initializer comes from an external file. Reuse it for compiled model.
+ return new OrtExternalInitializerInfo(originalFilePath, originalFileOffset, originalByteSize);
+ }
+
+ // Otherwise, embed initializers that were not originally external.
+ return null;
+ }
+
+ // Compile and generate an output model.
+ compileOptions.SetInputModelFromBuffer(model);
+ compileOptions.SetOutputModelPath(outputModelFilePath);
+ compileOptions.SetOutputModelGetInitializerLocationDelegate(ReuseExternalInitializers);
+ compileOptions.CompileModel();
+
+ Assert.True(File.Exists(outputModelFilePath));
+ Assert.True(reusedExternalInitializers);
+ }
+ }
+ finally
+ {
+ if (File.Exists(outputModelFilePath))
{
- File.Delete(output_model_file);
+ // This file is created by ORT, so we delete it manually in finally block.
+ File.Delete(outputModelFilePath);
}
}
}
diff --git a/csharp/testdata/conv_qdq_external_ini.bin b/csharp/testdata/conv_qdq_external_ini.bin
new file mode 100644
index 0000000000000..89eea0dba1fa4
Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.bin differ
diff --git a/csharp/testdata/conv_qdq_external_ini.onnx b/csharp/testdata/conv_qdq_external_ini.onnx
new file mode 100644
index 0000000000000..c53e1f3ad4d9b
Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.onnx differ
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 866892979b749..9a0708d72b4f8 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -1247,6 +1247,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
const std::filesystem::path& model_file_path,
const ModelSavingOptions& model_saving_options) const;
+ ///
+ /// Serialize the Graph to a onnx::GraphProto. Caller provides a function that determines where each initializer
+ /// is stored (i.e., either in an external file or within the model).
+ ///
+ /// Function called for every initializer.
+ /// Opaque user state passed to the handle_initializer_func.
+ /// Output parameter set to the serialized onnx::GraphProto.
+ /// A status indicating success or an error.
+ common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func,
+ void* state,
+ /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const;
+
/** Gets the ISchemaRegistry instances being used with this Graph. */
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const;
@@ -1664,6 +1676,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
std::ostream& external_stream,
int64_t& external_offset) const;
+ Status ToGraphProtoWithCustomInitializerHandlingImpl(OrtGetInitializerLocationFunc handle_initializer_func,
+ void* state,
+ /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const;
#endif
Version IrVersion() const noexcept {
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 5d0c273a218fe..81caf5069bb6e 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -535,6 +535,57 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e
_Out_ size_t* num_selected,
_In_ void* state);
+/** \brief Function called by ORT to write a buffer to a custom destination (e.g., file, stream, etc.).
+ *
+ * \param state Opaque pointer holding the user's state.
+ * \param buffer The buffer to write.
+ * \param buffer_num_bytes The size of the buffer in bytes.
+ *
+ * \return OrtStatus* Write status. Return nullptr on success.
+ * Use CreateStatus to provide error info. Use ORT_FAIL as the error code.
+ * ORT will release the OrtStatus* if not null.
+ */
+typedef OrtStatus*(ORT_API_CALL* OrtWriteBufferFunc)(_In_ void* state,
+ _In_ const void* buffer,
+ _In_ size_t buffer_num_bytes);
+
+/** \brief Function called by ORT to allow user to specify how an initializer should be saved, that is, either
+ * written to an external file or stored within the model. ORT calls this function for every initializer when
+ * generating a model.
+ *
+ * If the function implementation sets the `new_external_info` output parameter to NULL, ORT stores the initializer data
+ * within the generated model.
+ *
+ * Otherwise, if the function implementation sets `new_external_info` to a valid OrtExternalInitializerInfo instance,
+ * ORT assumes that this function stores the initializer data in a file. In this case, ORT configures the model's
+ * initializer to point to the location specified by the `new_external_info` output parameter.
+ *
+ * \param[in] state Opaque pointer holding the user's state.
+ * \param[in] initializer_name The initializer's name as a null-terminated string.
+ * \param[in] initializer_value OrtValue containing the initializer's data, type, and shape.
+ * \param[in] external_info If the initializer is originally stored in an external file, `external_info` contains
+ * the file path, file offset, and the data's byte size within the file. Otherwise,
+ * `external_info` is NULL if the initializer is not originally stored in a file.
+ * \param[out] new_external_info Output parameter set to a new OrtExternalInitializerInfo instance indicating the
+ * location where the function implementation stored the initializer data.
+ * The function implementation must use `OrtApi::CreateExternalInitializerInfo()` to
+ * create the instance.
+ * If the function implementation sets `new_external_info` to NULL,
+ * ORT stores the initializers within the model.
+ *
+ * \note ORT takes ownership of the `new_external_info` output parameter.
+ *
+ * \return OrtStatus* Write status. Return nullptr on success.
+ * Use CreateStatus to provide error info. Use ORT_FAIL as the error code.
+ * ORT will release the OrtStatus* if not null.
+ */
+typedef OrtStatus*(ORT_API_CALL* OrtGetInitializerLocationFunc)(
+ _In_ void* state,
+ _In_ const char* initializer_name,
+ _In_ const OrtValue* initializer_value,
+ _In_opt_ const OrtExternalInitializerInfo* external_info,
+ _Outptr_result_maybenull_ OrtExternalInitializerInfo** new_external_info);
+
/** \brief Algorithm to use for cuDNN Convolution Op
*/
typedef enum OrtCudnnConvAlgoSearch {
@@ -6509,6 +6560,26 @@ struct OrtApi {
_In_ size_t num_ep_devices,
_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* out_status);
+
+ /// \name OrtExternalInitializerInfo
+ /// @{
+
+ /** \brief Creates an OrtExternalInitializerInfo instance.
+ *
+ * \param[in] filepath The relative path to the file that stores the initializer's data. ORT copies this path string.
+ * \param[in] file_offset The byte offset where the initializer's data is stored within the file.
+ * \param[in] byte_size The size in bytes of the initializer's data within the file.
+ * \param[out] out Output parameter set to the new OrtExternalInitializerInfo instance.
+ * Must be released by calling ReleaseExternalInitializerInfo().
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset,
+ _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out);
+
+ /// @}
};
/*
@@ -7267,6 +7338,43 @@ struct OrtCompileApi {
ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel,
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ GraphOptimizationLevel graph_optimization_level);
+
+ /** \brief Sets a OrtWriteBufferFunc function that is called by ORT to write out the output model's serialized
+ * ONNX bytes.
+ *
+ * The provided write function may be called repeatedly until the entire output model has been written out. Each call
+ * to the write function is expected to consume the entire input buffer.
+ *
+ * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions
+ * that begin with ModelCompilationOptions_SetOutputModel____.
+ *
+ * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+ * \param[in] write_func The OrtWriteBufferFunc function called by ORT when writing out the model.
+ * \param[in] state Opaque state passed as the first argument to OrtWriteBufferFunc. Can be NULL.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelWriteFunc,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ OrtWriteBufferFunc write_func, _In_ void* state);
+
+ /** \brief Sets a OrtGetInitializerLocationFunc function that is called by ORT for every initializer in the generated
+ * model. Allows implementer to specify whether initializers should be stored within the model or externally.
+ *
+ * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+ * \param[in] get_initializer_location_func The OrtGetInitializerLocationFunc function called by ORT
+ * to determine the location of the initializer.
+ * \param[in] state Opaque state passed as the first argument to OrtGetInitializerLocationFunc. Can be NULL.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state);
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index af0f5046a3f9f..4a8c67e2215ec 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -904,6 +904,9 @@ struct ConstExternalInitializerInfoImpl : Base {
using ConstExternalInitializerInfo =
detail::ConstExternalInitializerInfoImpl>;
+/** \brief Wrapper around ::OrtExternalInitializerInfo
+ *
+ */
struct ExternalInitializerInfo : detail::ConstExternalInitializerInfoImpl {
using Base = detail::ConstExternalInitializerInfoImpl;
using Base::Base;
@@ -913,6 +916,13 @@ struct ExternalInitializerInfo : detail::ConstExternalInitializerInfoImpl{p} {}
ConstExternalInitializerInfo GetConst() const { return ConstExternalInitializerInfo{this->p_}; }
+
+ ///< Wraps OrtApi::CreateExternalInitializerInfo
+ ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size);
+
+ ///< Wrapper around CreateExternalInitializerInfo that does not throw an exception.
+ static Status Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size,
+ /*out*/ ExternalInitializerInfo& out);
};
namespace detail {
@@ -1454,8 +1464,18 @@ struct ModelCompilationOptions : detail::Base {
ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath
ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path,
size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile
+
+ ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc
+ ModelCompilationOptions& SetOutputModelGetInitializerLocationFunc(
+ OrtGetInitializerLocationFunc get_initializer_location_func,
+ void* state);
+
ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
+
+ ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelWriteFunc
+ ModelCompilationOptions& SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state);
+
ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 30ff4753f42a8..59979189eed0f 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -589,6 +589,24 @@ inline size_t ConstExternalInitializerInfoImpl::GetByteSize() const {
}
} // namespace detail
+inline ExternalInitializerInfo::ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset,
+ size_t byte_size) {
+ ThrowOnError(GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &this->p_));
+}
+
+inline Status ExternalInitializerInfo::Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size,
+ /*out*/ ExternalInitializerInfo& out) {
+ OrtExternalInitializerInfo* info = nullptr;
+ OrtStatus* status = GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &info);
+ if (status != nullptr) {
+ return Status{status};
+ }
+
+ out = ExternalInitializerInfo(info);
+
+ return Status{nullptr};
+}
+
namespace detail {
template
inline const char* KeyValuePairsImpl::GetValue(const char* key) const {
@@ -1021,6 +1039,16 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalI
return *this;
}
+inline ModelCompilationOptions&
+ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc(
+ OrtGetInitializerLocationFunc get_initializer_location_func, void* state) {
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc(
+ this->p_,
+ get_initializer_location_func,
+ state));
+ return *this;
+}
+
inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator,
@@ -1029,6 +1057,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
return *this;
}
+inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func,
+ void* state) {
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelWriteFunc(this->p_, write_func, state));
+ return *this;
+}
+
inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
bool embed_ep_context_in_model) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode(
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 38f034cf6d266..c0fe171f76037 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -33,6 +33,7 @@
OrtCompileApiFlags, # noqa: F401
OrtEpDevice, # noqa: F401
OrtExecutionProviderDevicePolicy, # noqa: F401
+ OrtExternalInitializerInfo, # noqa: F401
OrtHardwareDevice, # noqa: F401
OrtHardwareDeviceType, # noqa: F401
OrtMemoryInfo, # noqa: F401
diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc
new file mode 100644
index 0000000000000..abfd3cf89cecf
--- /dev/null
+++ b/onnxruntime/core/framework/ep_context_options.cc
@@ -0,0 +1,69 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <filesystem>
+#include <string>
+#include <utility>
+#include <variant>
+#include "core/common/common.h"
+#include "core/framework/ep_context_options.h"
+#include "core/session/onnxruntime_session_options_config_keys.h"
+
+namespace onnxruntime {
+namespace epctx {
+// class ModelGenOptions
+
+ModelGenOptions::ModelGenOptions() = default;
+
+ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) {
+ enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
+
+ std::string output_model_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
+ if (!output_model_path.empty()) {
+ output_model_location = std::filesystem::path(output_model_path);
+ } else {
+ output_model_location = std::monostate{};
+ }
+
+ std::string external_initializers_file_path = config_options.GetConfigOrDefault(
+ kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
+ if (!external_initializers_file_path.empty()) {
+ ExternalInitializerFileInfo ext_info = {};
+ ext_info.file_path = external_initializers_file_path;
+ ext_info.size_threshold = 0;
+ initializers_location = std::move(ext_info);
+ }
+
+ embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1";
+}
+
+bool ModelGenOptions::HasOutputModelLocation() const {
+  return !std::holds_alternative<std::monostate>(output_model_location);
+}
+
+const std::filesystem::path* ModelGenOptions::TryGetOutputModelPath() const {
+  return std::get_if<std::filesystem::path>(&output_model_location);
+}
+
+const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const {
+  return std::get_if<BufferHolder>(&output_model_location);
+}
+
+const BufferWriteFuncHolder* ModelGenOptions::TryGetOutputModelWriteFunc() const {
+  return std::get_if<BufferWriteFuncHolder>(&output_model_location);
+}
+
+bool ModelGenOptions::AreInitializersEmbeddedInOutputModel() const {
+  return std::holds_alternative<std::monostate>(initializers_location);
+}
+
+const ExternalInitializerFileInfo* ModelGenOptions::TryGetExternalInitializerFileInfo() const {
+  return std::get_if<ExternalInitializerFileInfo>(&initializers_location);
+}
+
+const InitializerHandler* ModelGenOptions::TryGetInitializerHandler() const {
+  return std::get_if<InitializerHandler>(&initializers_location);
+}
+
+} // namespace epctx
+} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h
new file mode 100644
index 0000000000000..6643516bfb4c3
--- /dev/null
+++ b/onnxruntime/core/framework/ep_context_options.h
@@ -0,0 +1,98 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <filesystem>
+#include <variant>
+#include "core/framework/allocator.h"
+#include "core/framework/config_options.h"
+
+namespace onnxruntime {
+namespace epctx {
+/// <summary>
+/// Holds the buffer that will store the output model and the allocator used to allocate the memory.
+/// </summary>
+struct BufferHolder {
+ void** buffer_ptr = nullptr;
+ size_t* buffer_size_ptr = nullptr;
+ AllocatorPtr buffer_allocator = nullptr;
+};
+
+/// <summary>
+/// Holds the opaque stream state and the write function that ORT calls to write out the output model.
+/// </summary>
+struct BufferWriteFuncHolder {
+ OrtWriteBufferFunc write_func = nullptr;
+ void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func.
+};
+
+/// <summary>
+/// Holds path and size threshold used to write out initializers to an external file.
+/// </summary>
+struct ExternalInitializerFileInfo {
+ std::filesystem::path file_path;
+ size_t size_threshold = 0;
+};
+
+/// <summary>
+/// Holds function and state provided by user to handle initializer data (i.e., write to stream or embed in model).
+/// </summary>
+struct InitializerHandler {
+ OrtGetInitializerLocationFunc handle_initializer_func = nullptr;
+ void* state = nullptr;
+};
+
+/// <summary>
+/// Stores EPContext model generation options. Used in SessionOptions.
+/// </summary>
+struct ModelGenOptions {
+ // Action to take if the output model does not have compiled (EPContext) nodes.
+ enum class ActionIfNoCompiledNodes {
+ // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior
+ // to maintain compatibility. The explicit compile API does *not* use this action.
+ kDontGenerateModel = 0,
+
+ // Generate an output model even if it doesn't have compiled nodes.
+ // The explicit Compile API defaults to this value.
+ kGenerateModel,
+
+ // Return an error if the model does not have compiled nodes.
+ // The explicit Compile API can be configured to this value.
+ kReturnError,
+ };
+
+ ModelGenOptions();
+
+ // Initializes from string key/value pairs in session config options.
+ explicit ModelGenOptions(const ConfigOptions& config_options);
+
+ bool enable = false;
+ bool error_if_output_file_exists = true;
+ bool error_if_no_compiled_nodes = false;
+ bool embed_ep_context_in_model = false;
+ ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel;
+
+  std::variant<std::monostate,
+               std::filesystem::path,
+               BufferHolder,
+               BufferWriteFuncHolder>  // Function to write the output model to a user's stream.
+      output_model_location = std::monostate{};
+
+  std::variant<std::monostate,
+               ExternalInitializerFileInfo,
+               InitializerHandler>  // Custom function called for every initializer to determine location.
+      initializers_location = std::monostate{};
+
+ bool HasOutputModelLocation() const;
+ const std::filesystem::path* TryGetOutputModelPath() const;
+ const BufferHolder* TryGetOutputModelBuffer() const;
+ const BufferWriteFuncHolder* TryGetOutputModelWriteFunc() const;
+
+ bool AreInitializersEmbeddedInOutputModel() const;
+ const ExternalInitializerFileInfo* TryGetExternalInitializerFileInfo() const;
+ const InitializerHandler* TryGetInitializerHandler() const;
+};
+
+} // namespace epctx
+} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/ep_context_utils.cc b/onnxruntime/core/framework/ep_context_utils.cc
new file mode 100644
index 0000000000000..3f02c54538526
--- /dev/null
+++ b/onnxruntime/core/framework/ep_context_utils.cc
@@ -0,0 +1,126 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if !defined(ORT_MINIMAL_BUILD)
+#include <limits>
+#include <utility>
+#include "core/framework/ep_context_utils.h"
+#include "core/framework/error_code_helper.h"
+#include "core/graph/model_saving_options.h"
+
+namespace onnxruntime {
+namespace epctx {
+
+// Serialize an EPContext model into a onnx::ModelProto.
+Status EpContextModelToProto(const onnxruntime::Model& ep_context_model,
+ const std::filesystem::path& validated_model_path,
+ const epctx::ModelGenOptions& ep_context_gen_options,
+ /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
+ // Handle case where initializers are stored inline within the ONNX model.
+ if (ep_context_gen_options.AreInitializersEmbeddedInOutputModel()) {
+ // if no external ini file specified, set force_embed_external_ini to true to avoid intermediate file creation
+ // and force all initializers embed into the ONNX file.
+ ModelSavingOptions model_saving_options{/*size_threshold*/ SIZE_MAX};
+ model_saving_options.force_embed_external_ini = true;
+
+ model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(std::filesystem::path{},
+ validated_model_path,
+ model_saving_options);
+ return Status::OK();
+ }
+
+ // Handle case where initializers (with size > threshold) are stored in an external file.
+ if (const epctx::ExternalInitializerFileInfo* ext_info = ep_context_gen_options.TryGetExternalInitializerFileInfo();
+ ext_info != nullptr) {
+ ModelSavingOptions model_saving_options{ext_info->size_threshold};
+
+ model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(ext_info->file_path,
+ validated_model_path,
+ model_saving_options);
+ return Status::OK();
+ }
+
+ // Handle case where user specified a custom handler function that determines how each initializer is saved.
+ if (const epctx::InitializerHandler* custom_handler = ep_context_gen_options.TryGetInitializerHandler();
+ custom_handler != nullptr) {
+ ORT_RETURN_IF_ERROR(ep_context_model.ToGraphProtoWithCustomInitializerHandling(
+ custom_handler->handle_initializer_func,
+ custom_handler->state,
+ model_proto));
+ return Status::OK();
+ }
+
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected location for initializers while generating ",
+ validated_model_path);
+}
+
+//
+// OutStreamBuf class:
+//
+
+OutStreamBuf::OutStreamBuf(BufferWriteFuncHolder write_func_holder)
+ : write_func_holder_(write_func_holder), buffer_(65536) {
+ setp(buffer_.data(), buffer_.data() + buffer_.size());
+}
+
+OutStreamBuf::~OutStreamBuf() {
+ sync();
+}
+
+// Called when the buffer_ is full. Flushes the buffer_ (via sync()) and then writes the overflow character to buffer_.
+std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) {
+ if (sync() == -1) {
+ return traits_type::eof();
+ }
+
+ if (ch != traits_type::eof()) {
+    *pptr() = static_cast<char>(ch);
+ pbump(1);
+ }
+
+ return ch;
+}
+
+// Flushes the entire buffer_ to the user's write function.
+int OutStreamBuf::sync() {
+ if (!last_status_.IsOK()) {
+ return -1;
+ }
+
+ std::ptrdiff_t num_bytes = pptr() - pbase();
+ if (num_bytes == 0) {
+ return 0;
+ }
+
+ // Can only call pbump() with an int, so can only write at most (2^31 - 1) bytes.
+  if (num_bytes > std::numeric_limits<int>::max()) {
+    num_bytes = std::numeric_limits<int>::max();
+ }
+
+ char* ptr = pbase();
+
+ Status status = Status::OK();
+
+ ORT_TRY {
+ status = ToStatusAndRelease(write_func_holder_.write_func(write_func_holder_.stream_state,
+ ptr, num_bytes));
+ }
+ ORT_CATCH(const std::exception& e) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what());
+ });
+ }
+
+ if (!status.IsOK()) {
+ last_status_ = std::move(status);
+ return -1;
+ }
+
+  pbump(-static_cast<int>(num_bytes));  // Reset internal pointer to point to the beginning of the buffer_
+ return 0;
+}
+
+} // namespace epctx
+} // namespace onnxruntime
+#endif // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/core/framework/ep_context_utils.h b/onnxruntime/core/framework/ep_context_utils.h
new file mode 100644
index 0000000000000..b3c76565982ff
--- /dev/null
+++ b/onnxruntime/core/framework/ep_context_utils.h
@@ -0,0 +1,61 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_MINIMAL_BUILD)
+
+#include <filesystem>
+#include <streambuf>
+#include <vector>
+
+#include "core/common/status.h"
+#include "core/framework/ep_context_options.h"
+#include "core/graph/model.h"
+
+namespace onnxruntime {
+namespace epctx {
+
+/// <summary>
+/// Serialize an EPContext model into a onnx::ModelProto based on the provided options.
+/// </summary>
+/// <param name="ep_context_model">The EP Context model to serialize.</param>
+/// <param name="validated_model_path">The path into which to save the model. May be empty if serialized into a
+/// buffer or output stream.</param>
+/// <param name="ep_context_gen_options">The model generation options.</param>
+/// <param name="model_proto">Output parameter set to the serialized onnx::ModelProto.</param>
+/// <returns>A status indicating success or an error.</returns>
+Status EpContextModelToProto(const onnxruntime::Model& ep_context_model,
+ const std::filesystem::path& validated_model_path,
+ const epctx::ModelGenOptions& ep_context_gen_options,
+ /*out*/ ONNX_NAMESPACE::ModelProto& model_proto);
+
+// Class that wraps the user's OrtWriteBufferFunc function to enable use with
+// C++'s std::ostream.
+// Example:
+// BufferWriteFuncHolder write_func_holder{write_func, stream_state};
+//   std::unique_ptr<OutStreamBuf> out_stream_buf = std::make_unique<OutStreamBuf>(write_func_holder);
+// std::ostream out_stream(out_stream_buf.get());
+class OutStreamBuf : public std::streambuf {
+ public:
+ explicit OutStreamBuf(BufferWriteFuncHolder write_func_holder);
+ ~OutStreamBuf();
+
+ const Status& GetStatus() const {
+ return last_status_;
+ }
+
+ protected:
+ int_type overflow(int_type ch) override;
+ int sync() override;
+
+ private:
+ BufferWriteFuncHolder write_func_holder_{};
+  std::vector<char> buffer_;
+ Status last_status_{};
+};
+
+} // namespace epctx
+} // namespace onnxruntime
+
+#endif // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index 421e5a6db51b7..43caf4766d5c0 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -5,10 +5,12 @@
#include
#include
+#include
#include "core/common/inlined_containers.h"
#include "core/common/string_utils.h"
#include "core/framework/compute_capability.h"
+#include "core/framework/ep_context_utils.h"
#include "core/framework/execution_providers.h"
#include "core/framework/func_kernel.h"
#include "core/framework/kernel_lookup.h"
@@ -20,9 +22,9 @@
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
-#include "core/graph/model_saving_options.h"
-#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"
+#include "core/session/onnxruntime_session_options_config_keys.h"
+#include "core/util/protobuf_parsing_utils.h"
// uncomment this line to count non-CUDA ops in ONNX domain
// #define COUNT_NON_CUDA_OPS
@@ -766,6 +768,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
}
// Validate the ep_context_path to make sure it is file path and check whether the file exist already
+// TODO: Move function to ep_context_utils.h/cc
static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path,
const std::filesystem::path& model_path,
std::filesystem::path& context_cache_path,
@@ -794,9 +797,10 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_
return Status::OK();
}
+// TODO: Move function to ep_context_utils.h/cc
static Status CreateEpContextModel(const ExecutionProviders& execution_providers,
const Graph& graph,
- const EpContextModelGenerationOptions& ep_context_gen_options,
+ const epctx::ModelGenOptions& ep_context_gen_options,
const logging::Logger& logger) {
InlinedVector all_ep_context_nodes;
for (const auto& ep : execution_providers) {
@@ -807,11 +811,11 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
if (all_ep_context_nodes.size() < 1) {
auto action_if_no_compiled_nodes = ep_context_gen_options.action_if_no_compiled_nodes;
- ORT_RETURN_IF(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError,
+ ORT_RETURN_IF(action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kReturnError,
"Unable to compile any nodes. Check that the session EPs support compilation and can execute "
"at least one subgraph in the model.");
- if (action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kDontGenerateModel) {
+ if (action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kDontGenerateModel) {
LOGS(logger, WARNING) << "Unable to compile any nodes. ONNX Runtime will not generate a compiled model. "
"Either the session EPs do not support compilation or the model is already compiled.";
// Note: this path is only taken if a model is compiled with the original compilation approach that uses
@@ -821,7 +825,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
}
// Assert so that this is caught in a test in DEBUG builds (in case a new enum value is added)
- assert(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel);
+ assert(action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel);
LOGS(logger, INFO) << "Unable to compile any nodes but will still generate an output model. "
"Either the session EPs do not support compilation or the model is already compiled.";
}
@@ -835,15 +839,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
return std::make_pair(false, static_cast<const Node*>(nullptr));
};
- bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr &&
- ep_context_gen_options.output_model_buffer_size_ptr != nullptr &&
- ep_context_gen_options.output_model_buffer_allocator != nullptr;
+ const epctx::BufferHolder* output_buffer_holder = ep_context_gen_options.TryGetOutputModelBuffer();
+ const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc();
+ const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath();
- std::filesystem::path context_cache_path;
- if (!saving_to_buffer || !graph.ModelPath().empty()) {
- ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path,
+ std::filesystem::path valid_output_model_path;
+ if (output_model_path_ptr != nullptr || !graph.ModelPath().empty()) {
+ std::filesystem::path output_model_path = (output_model_path_ptr != nullptr) ? *output_model_path_ptr
+ : std::filesystem::path("");
+ ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(output_model_path,
graph.ModelPath(),
- context_cache_path,
+ valid_output_model_path,
ep_context_gen_options.error_if_output_file_exists));
}
@@ -910,10 +916,11 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
}
}
+ ORT_RETURN_IF_ERROR(ep_graph.Resolve());
+
// Generate EP compatibility strings for OrtEp types and add to model metadata
// At this point, the graph has been populated with all the EPContext nodes
{
- ORT_RETURN_IF_ERROR(ep_graph.Resolve());
const GraphViewer graph_viewer(ep_graph);
for (const auto& ep : execution_providers) {
try {
@@ -938,39 +945,60 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
}
}
- size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold;
- std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path;
- bool force_embed_external_ini = false;
- if (external_ini_path.empty()) {
- // if no external ini file specified, set force_embed_external_ini to true to avoid intermedia file creation
- // and force all initializers embed into the Onnx file
- ini_size_threshold = SIZE_MAX;
- force_embed_external_ini = true;
- }
-
- ModelSavingOptions model_saving_options{ini_size_threshold};
- model_saving_options.force_embed_external_ini = force_embed_external_ini;
+ ONNX_NAMESPACE::ModelProto model_proto;
+ ORT_RETURN_IF_ERROR(EpContextModelToProto(ep_context_model, valid_output_model_path, ep_context_gen_options,
+ /*out*/ model_proto));
- if (saving_to_buffer) {
- ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve());
- // TODO(adrianlizarraga): Investigate if we can make this more memory efficient.
- // May be able to use allocator to directly allocate the ModelProto to avoid a copy.
- ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path,
- context_cache_path,
- model_saving_options);
+ if (output_buffer_holder != nullptr) {
+ // Write output model into a buffer ORT allocates for the user.
size_t buffer_size = model_proto.ByteSizeLong();
ORT_RETURN_IF(buffer_size > static_cast<size_t>(std::numeric_limits<int>::max()),
"Cannot serialize ONNX ModelProto larger than 2GB");
- AllocatorPtr allocator = ep_context_gen_options.output_model_buffer_allocator;
+ AllocatorPtr allocator = output_buffer_holder->buffer_allocator;
IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_size);
model_proto.SerializeToArray(buffer.get(), static_cast<int>(buffer_size));
- *ep_context_gen_options.output_model_buffer_size_ptr = buffer_size;
- *ep_context_gen_options.output_model_buffer_ptr = buffer.release();
+ *output_buffer_holder->buffer_size_ptr = buffer_size;
+ *output_buffer_holder->buffer_ptr = buffer.release();
+ } else if (output_write_func_holder != nullptr) {
+ // Write output model to user's output stream.
+ size_t buffer_size = model_proto.ByteSizeLong();
+    ORT_RETURN_IF(buffer_size > static_cast<size_t>(std::numeric_limits<int>::max()),
+ "Cannot serialize ONNX ModelProto larger than 2GB");
+
+    auto out_stream_buf = std::make_unique<epctx::OutStreamBuf>(*output_write_func_holder);
+ std::ostream out_stream(out_stream_buf.get());
+
+ model_proto.SerializeToOstream(&out_stream);
+ out_stream.flush();
+ ORT_RETURN_IF_ERROR(out_stream_buf->GetStatus());
} else {
- ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, context_cache_path,
- external_ini_path, model_saving_options));
+ // Write output model to a file.
+ int fd = 0;
+ Status status = Env::Default().FileOpenWr(valid_output_model_path, fd);
+ ORT_RETURN_IF_ERROR(status);
+
+ ORT_TRY {
+ google::protobuf::io::FileOutputStream output(fd);
+ bool serialize_result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush();
+ if (!serialize_result) {
+ status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_PROTOBUF,
+ "Protobuf serialization failed when generating EPContext model ",
+ valid_output_model_path);
+ }
+ }
+ ORT_CATCH(const std::exception& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
+ });
+ }
+ if (!status.IsOK()) {
+ GSL_SUPPRESS(es .84)
+ ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
+ return status;
+ }
+ ORT_RETURN_IF_ERROR(Env::Default().FileClose(fd));
}
return Status::OK();
@@ -1221,7 +1249,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
const ConfigOptions& config_options,
const logging::Logger& logger,
Mode mode,
- const EpContextModelGenerationOptions& ep_context_gen_options,
+ const epctx::ModelGenOptions& ep_context_gen_options,
const layout_transformation::DebugGraphFn& debug_graph_fn) const {
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
// 1. Execution providers' capabilities are checked one by one.
@@ -1268,12 +1296,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
- if (ep_context_gen_options.enable && ep_context_gen_options.output_model_buffer_ptr == nullptr) {
- // Check before EP compile graphs
- std::filesystem::path context_cache_path;
- ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(),
- context_cache_path,
- ep_context_gen_options.error_if_output_file_exists));
+ if (ep_context_gen_options.enable) {
+ if (const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath();
+ output_model_path_ptr != nullptr) {
+ // Check before EP compile graphs
+ std::filesystem::path context_cache_path;
+ ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(*output_model_path_ptr, graph.ModelPath(),
+ context_cache_path,
+ ep_context_gen_options.error_if_output_file_exists));
+ }
}
// We use this only if Resource Aware Partitioning is enabled for any of the EPs
diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h
index 6e36d79701fd7..abe46cea58ab2 100644
--- a/onnxruntime/core/framework/graph_partitioner.h
+++ b/onnxruntime/core/framework/graph_partitioner.h
@@ -15,7 +15,10 @@ class ExecutionProviders;
class KernelRegistryManager;
class Model;
struct ConfigOptions;
-struct EpContextModelGenerationOptions;
+
+namespace epctx {
+struct ModelGenOptions;
+}
class GraphPartitioner {
public:
@@ -50,7 +53,7 @@ class GraphPartitioner {
const ConfigOptions& config_options,
const logging::Logger& logger,
Mode mode = Mode::kNormal,
- const EpContextModelGenerationOptions& ep_context_gen_options = {},
+ const epctx::ModelGenOptions& ep_context_gen_options = {},
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;
bool IsLoadCancellationFlagSet() const {
diff --git a/onnxruntime/core/framework/session_options.cc b/onnxruntime/core/framework/session_options.cc
index 231eb47603838..63f928d52d788 100644
--- a/onnxruntime/core/framework/session_options.cc
+++ b/onnxruntime/core/framework/session_options.cc
@@ -99,20 +99,11 @@ void SessionOptions::AddCustomOpLibraryHandle(PathString library_name, void* lib
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
-EpContextModelGenerationOptions::EpContextModelGenerationOptions(const ConfigOptions& config_options) {
- enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
- output_model_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
- output_external_initializers_file_path = config_options.GetConfigOrDefault(
- kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
- output_external_initializer_size_threshold = 0;
- embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1";
-}
-
-EpContextModelGenerationOptions SessionOptions::GetEpContextGenerationOptions() const {
+epctx::ModelGenOptions SessionOptions::GetEpContextGenerationOptions() const {
if (this->has_explicit_ep_context_gen_options) {
return this->ep_context_gen_options;
}
- return EpContextModelGenerationOptions(this->config_options);
+ return epctx::ModelGenOptions(this->config_options);
}
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h
index b75eeb217e7f0..b328fc916f885 100644
--- a/onnxruntime/core/framework/session_options.h
+++ b/onnxruntime/core/framework/session_options.h
@@ -13,6 +13,7 @@
#include "core/common/inlined_containers.h"
#include "core/framework/allocator.h"
#include "core/framework/config_options.h"
+#include "core/framework/ep_context_options.h"
#include "core/framework/ort_value.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/optimizer/graph_transformer_level.h"
@@ -70,53 +71,6 @@ struct FreeDimensionOverride {
using CheckLoadCancellationFn = std::function<bool()>;
-/// <summary>
-/// Options that configure the generation of a compiled model (i.e., a model with EPContext nodes).
-/// There are two ways to compile a model:
-/// 1. By specifying the correct session option configurations and creating an inference session.
-/// The compiled model is generated as a side-effect of session creation.
-/// 2. Using an explicit compile API (see OrtCompileApi struct in onnxruntime_c_api.h).
-///
-/// The default values in this struct are set to match the current/default behavior of approach 1 to maintain
-/// compatibility with the older way of compiling. The explicit compile API overrides some of these values to
-/// provide its own defaults (see core/session/model_compilation_options.h/cc).
-/// </summary>
-struct EpContextModelGenerationOptions {
- // Action to take if the output model does not have compiled (EPContext) nodes.
- enum class ActionIfNoCompiledNodes {
- // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior
- // to maintain compatibility. The explicit compile API does *not* use this action.
- kDontGenerateModel = 0,
-
- // Generate an output model even if it doesn't have compiled nodes.
- // The explicit Compile API defaults to this value.
- kGenerateModel,
-
- // Return an error if the model does not have compiled nodes.
- // The explicit Compile API can be configured to this value.
- kReturnError,
- };
-
- EpContextModelGenerationOptions() = default;
-
- // Initializes from string key/value pairs in session config options.
- // This initializes this struct from options set via the older compiling approach #1 above.
- explicit EpContextModelGenerationOptions(const ConfigOptions& config_options);
-
- bool enable = false;
- bool error_if_output_file_exists = true;
- ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel;
- bool embed_ep_context_in_model = false;
-
- std::string output_model_file_path;
- void** output_model_buffer_ptr = nullptr;
- size_t* output_model_buffer_size_ptr = nullptr;
- AllocatorPtr output_model_buffer_allocator = nullptr;
-
- std::string output_external_initializers_file_path;
- size_t output_external_initializer_size_threshold = 0;
-};
-
struct EpSelectionPolicy {
// flag to detect that a policy was set by the user.
// need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered
@@ -270,8 +224,8 @@ struct SessionOptions {
// The function GetEpContextGenerationOptions() handles conversion of string key/value pairs to the new
// struct type.
bool has_explicit_ep_context_gen_options = false;
- EpContextModelGenerationOptions ep_context_gen_options = {};
- EpContextModelGenerationOptions GetEpContextGenerationOptions() const;
+ epctx::ModelGenOptions ep_context_gen_options = {};
+ epctx::ModelGenOptions GetEpContextGenerationOptions() const;
};
inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) {
diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc
index d7f5b23d56c70..dfdb3ba962609 100644
--- a/onnxruntime/core/framework/tensor_external_data_info.cc
+++ b/onnxruntime/core/framework/tensor_external_data_info.cc
@@ -18,6 +18,13 @@ using ::google::protobuf::RepeatedPtrField;
using ::ONNX_NAMESPACE::StringStringEntryProto;
namespace onnxruntime {
+ExternalDataInfo::ExternalDataInfo() = default;
+
+#if !defined(ORT_MINIMAL_BUILD)
+ExternalDataInfo::ExternalDataInfo(const PathString& rel_path, OFFSET_TYPE offset, size_t length)
+ : rel_path_(rel_path), offset_(offset), length_(length) {}
+#endif
+
Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>& input,
std::unique_ptr& external_data_info_result) {
auto external_data_info = std::make_unique();
diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h
index 784b3f352a78e..aa9bb32922bd7 100644
--- a/onnxruntime/core/framework/tensor_external_data_info.h
+++ b/onnxruntime/core/framework/tensor_external_data_info.h
@@ -25,6 +25,12 @@ class ExternalDataInfo {
using OFFSET_TYPE = off_t;
#endif
+ ExternalDataInfo();
+
+#if !defined(ORT_MINIMAL_BUILD)
+ ExternalDataInfo(const PathString& rel_path, OFFSET_TYPE offset, size_t length);
+#endif
+
const PathString& GetRelPath() const { return rel_path_; }
OFFSET_TYPE GetOffset() const { return offset_; }
diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h
index 2ef7c4a9091f3..c5d7d4cc4e68c 100644
--- a/onnxruntime/core/graph/abi_graph_types.h
+++ b/onnxruntime/core/graph/abi_graph_types.h
@@ -31,8 +31,10 @@ enum class OrtGraphIrApi {
kEpApi,
};
-// Alias OrtExternalInitializerInfo to the internal type.
-struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo {};
+// Alias OrtExternalInitializerInfo to the internal onnxruntime::ExternalDataInfo type.
+struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo {
+ using onnxruntime::ExternalDataInfo::ExternalDataInfo; // inherit constructors
+};
///
/// Public type that represents an ONNX value info.
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 0a228176175eb..9a97711996343 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -17,6 +17,7 @@
#include "core/common/logging/logging.h"
#include "core/common/narrow.h"
#include "core/flatbuffers/flatbuffers_utils.h"
+#include "core/framework/error_code_helper.h"
#include "core/framework/tensor_type_and_shape.h"
#include "core/flatbuffers/schema/ort.fbs.h"
#include "core/framework/tensor_external_data_info.h"
@@ -4357,14 +4358,23 @@ Status Graph::RegenerateInitializersAndReplaceInMemory(gsl::span& subgraphs) {
+ for (const auto& node : nodes) {
if (node.ContainsSubgraph()) {
// Let's find this node in the output_graph_proto
// The node name is optional, so we may need to check by the output value name
// given that they can only assigned once.
- auto hit = std::find_if(output_graph_proto.mutable_node()->begin(),
- output_graph_proto.mutable_node()->end(),
+ auto hit = std::find_if(graph_proto.mutable_node()->begin(),
+ graph_proto.mutable_node()->end(),
[&node](const ONNX_NAMESPACE::NodeProto& proto) {
const auto& node_name = node.Name();
if (!node_name.empty())
@@ -4372,7 +4382,7 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr
return (proto.output_size() > 0 &&
proto.output(0) == node.OutputDefs()[0]->Name());
});
- ORT_RETURN_IF_NOT(hit != output_graph_proto.mutable_node()->end(), "Node ", node.Name(),
+ ORT_RETURN_IF_NOT(hit != graph_proto.mutable_node()->end(), "Node ", node.Name(),
" not found in output_graph_proto");
auto& result_node = *hit;
for (const auto& e : node.GetAttributeNameToSubgraphMap()) {
@@ -4387,12 +4397,28 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr
ORT_RETURN_IF_NOT(sub_hit != result_node.mutable_attribute()->end() && utils::HasGraph(*sub_hit),
"Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ",
node.Name(), " while attempting to recurse into it.");
- auto& result_subgraph = *sub_hit->mutable_g();
- ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph));
+ SubgraphWithMutableProto subgraph_result{sub_hit->mutable_g(), subgraph};
+ subgraphs.emplace_back(subgraph_result);
}
}
}
+ return Status::OK();
+}
+
+Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
+ // Process subgraphs recursively (bottom-up).
+ {
+ std::vector subgraphs;
+ ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs));
+
+ for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) {
+ gsl::not_null subgraph = subgraph_and_proto.subgraph;
+ gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto;
+ ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(*subgraph_proto));
+ }
+ }
+
// Filter in iterators for weights that are present in the name_to_initial_tensor_ map
// and preserve the order. This is needed for tests.
InlinedVector initializers_to_process;
@@ -4444,44 +4470,19 @@ Status Graph::AddExternalInitializersToGraphProtoImpl(
// Process initializers in a subgraph, check their size and
// write to an external file. This function also saves pre-packed
// blobs for the initializer being saved to disk, if the initializer has any pre-packs.
- // This function is invoked by ToGraphProtoWithExternalInitiallizers() and processes subgraphs
+ // This function is invoked by ToGraphProtoWithExternalInitializers() and processes subgraphs
// bottom up.
- for (const auto& node : Nodes()) {
- if (node.ContainsSubgraph()) {
- // Let's find this node in the output_graph_proto
- // The node name is optional, so we may need to check by the output value name
- // given that they can only assigned once.
- auto hit = std::find_if(output_graph_proto.mutable_node()->begin(),
- output_graph_proto.mutable_node()->end(),
- [&node](const ONNX_NAMESPACE::NodeProto& proto) {
- const auto& node_name = node.Name();
- if (!node_name.empty())
- return proto.name() == node_name;
- return (proto.output_size() > 0 &&
- proto.output(0) == node.OutputDefs()[0]->Name());
- });
- ORT_RETURN_IF_NOT(hit != output_graph_proto.mutable_node()->end(), "Node ", node.Name(),
- " not found in output_graph_proto");
- auto& result_node = *hit;
- for (const auto& e : node.GetAttributeNameToSubgraphMap()) {
- const auto& name = e.first;
- const auto& subgraph = e.second;
- // Lets find this subgraph in the result_node
- auto sub_hit = std::find_if(result_node.mutable_attribute()->begin(),
- result_node.mutable_attribute()->end(),
- [&name](const ONNX_NAMESPACE::AttributeProto& proto) {
- return proto.name() == name;
- });
- ORT_RETURN_IF_NOT(sub_hit != result_node.mutable_attribute()->end() && utils::HasGraph(*sub_hit),
- "Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ",
- node.Name(), " while attempting to recurse into it.");
- auto& result_subgraph = *sub_hit->mutable_g();
- ORT_RETURN_IF_ERROR(subgraph->AddExternalInitializersToGraphProtoImpl(
- model_path, external_file_path,
- model_external_file_path, model_saving_options,
- result_subgraph,
- external_stream, external_offset));
- }
+ {
+ std::vector subgraphs;
+ ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs));
+
+ for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) {
+ gsl::not_null subgraph = subgraph_and_proto.subgraph;
+ gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto;
+ ORT_RETURN_IF_ERROR(subgraph->AddExternalInitializersToGraphProtoImpl(
+ model_path, external_file_path,
+ model_external_file_path, model_saving_options,
+ *subgraph_proto, external_stream, external_offset));
}
}
@@ -4643,6 +4644,113 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(
return result;
}
+Status Graph::ToGraphProtoWithCustomInitializerHandlingImpl(
+ OrtGetInitializerLocationFunc handle_initializer_func,
+ void* state,
+ /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
+ // This loop processes subgraphs bottom up.
+ {
+ std::vector subgraphs;
+ ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs));
+
+ for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) {
+ gsl::not_null subgraph = subgraph_and_proto.subgraph;
+ gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto;
+ ORT_RETURN_IF_ERROR(subgraph->ToGraphProtoWithCustomInitializerHandlingImpl(handle_initializer_func,
+ state, *subgraph_proto));
+ }
+ }
+
+ // Create a sorted std::vector of initializers so that we always process them in a deterministic order.
+ InlinedVector initializers;
+ initializers.reserve(GetAllInitializedTensors().size());
+
+ for (const auto& [name, initializer_tp] : GetAllInitializedTensors()) {
+ initializers.push_back(initializer_tp);
+ }
+
+ std::sort(initializers.begin(), initializers.end(),
+ [](const ONNX_NAMESPACE::TensorProto* a, const ONNX_NAMESPACE::TensorProto* b) {
+ return a->name() < b->name();
+ });
+
+ // Call user's handler function for each initializer. We store the initializer externally
+ // or within the model depending on the result returned by the handler function.
+ for (gsl::not_null initializer : initializers) {
+#if !defined(DISABLE_SPARSE_TENSORS)
+ if (IsSparseInitializer(initializer->name())) {
+ // Sparse tensors are added to the ONNX file directly.
+ auto& sparse_initializer = *output_graph_proto.add_sparse_initializer();
+ ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(*initializer, ModelPath(), sparse_initializer));
+ } else {
+#endif
+ TensorProto* output_proto = output_graph_proto.add_initializer();
+
+ output_proto->set_name(initializer->name());
+ output_proto->set_data_type(initializer->data_type());
+ for (int i = 0; i != initializer->dims_size(); ++i) {
+ output_proto->add_dims(initializer->dims(i));
+ }
+ output_proto->set_doc_string(initializer->doc_string());
+
+ OrtValue ort_value;
+ std::unique_ptr original_ext_data_info = nullptr;
+
+ if (utils::HasExternalDataInFile(*initializer)) {
+ // Initializer has data in an external file. Load it into OrtValue (potentially via memory mapping).
+ ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(initializer->external_data(), original_ext_data_info));
+ ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(Env::Default(), ModelPath(), *initializer, ort_value));
+ } else {
+ // Initializer is either stored inline within the TensorProto or it is "external data in memory".
+ // Get an OrtValue (if already loaded by Graph) or copy into an OrtValue otherwise.
+ bool graph_has_ort_value = GetOrtValueInitializer(initializer->name(), ort_value, /*check_outer_scope*/ false);
+ if (!graph_has_ort_value) {
+ assert(!utils::HasExternalData(*initializer));
+ ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), ModelPath(), *initializer,
+ CPUAllocator::DefaultInstance(), ort_value));
+ }
+ }
+
+ // Call the user's initializer handler function. If the user wants to store the initializer externally,
+ // the handler function will use OrtApi::CreateExternalInitializerInfo() to create a new
+ // OrtExternalInitializerInfo instance that indicates the location of the data.
+ OrtExternalInitializerInfo* new_external_info = nullptr;
+ Status status = ToStatusAndRelease(handle_initializer_func(state, initializer->name().c_str(),
+ &ort_value,
+ static_cast(original_ext_data_info.get()),
+ &new_external_info));
+
+ ORT_RETURN_IF(new_external_info != nullptr &&
+ new_external_info == static_cast(original_ext_data_info.get()),
+ "User's OrtGetInitializerLocationFunc must not return the external_info parameter. ",
+ "Return a copy instead.");
+ std::unique_ptr new_external_info_holder(new_external_info); // Take ownership
+ ORT_RETURN_IF_ERROR(status);
+
+ if (new_external_info != nullptr) {
+ ExternalDataInfo::SetExternalLocationToProto(new_external_info->GetRelPath(), new_external_info->GetOffset(),
+ new_external_info->GetLength(), *output_proto);
+ } else {
+ const Tensor& tensor = ort_value.Get();
+ output_proto->clear_data_location();
+ utils::SetRawDataInTensorProto(*output_proto, tensor.DataRaw(), tensor.SizeInBytes());
+ }
+#if !defined(DISABLE_SPARSE_TENSORS)
+ }
+#endif
+ }
+
+ return Status::OK();
+}
+
+Status Graph::ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func,
+ void* state,
+ /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const {
+ ToGraphProtoInternal(graph_proto);
+ ORT_RETURN_IF_ERROR(ToGraphProtoWithCustomInitializerHandlingImpl(handle_initializer_func, state, graph_proto));
+ return Status::OK();
+}
+
void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const {
graph_proto_->clear_node();
graph_proto_->clear_input();
diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc
index eb5e1e89e2f9c..0ffbced51ee35 100644
--- a/onnxruntime/core/graph/model.cc
+++ b/onnxruntime/core/graph/model.cc
@@ -415,6 +415,25 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa
return result;
}
+common::Status Model::ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func,
+ void* state,
+ /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) const {
+ model_proto = model_proto_;
+
+ // Sync current model_metadata_ back to protobuf metadata_props
+ model_proto.clear_metadata_props();
+ for (const auto& metadata : model_metadata_) {
+ const gsl::not_null prop{model_proto.add_metadata_props()};
+ prop->set_key(metadata.first);
+ prop->set_value(metadata.second);
+ }
+
+ const auto& graph = *graph_;
+ ORT_RETURN_IF_ERROR(graph.ToGraphProtoWithCustomInitializerHandling(handle_initializer_func,
+ state, *model_proto.mutable_graph()));
+ return Status::OK();
+}
+
Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) {
if (!model_istream.good()) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid istream object.");
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index e8722f6f5c0b2..c86aac44806bd 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -210,6 +210,18 @@ class Model {
const std::filesystem::path& file_path,
const ModelSavingOptions& model_saving_options) const;
+ ///
+ /// Serialize the Model to an onnx::ModelProto. Caller provides a function that determines where each initializer
+ /// is stored (i.e., either in an external file or within the model).
+ ///
+ /// Function called for every initializer.
+ /// Opaque user state passed to the handle_initializer_func.
+ /// Output parameter set to the serialized onnx::ModelProto.
+ /// A status indicating success or an error.
+ common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func,
+ void* state,
+ /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) const;
+
static common::Status Save(Model& model, const PathString& file_path);
static common::Status Save(Model& model, int fd);
diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc
index 759773042debb..b9a54ea7104e1 100644
--- a/onnxruntime/core/session/compile_api.cc
+++ b/onnxruntime/core/session/compile_api.cc
@@ -64,7 +64,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModelPath,
API_IMPL_BEGIN
#if !defined(ORT_MINIMAL_BUILD)
auto model_compile_options = reinterpret_cast(ort_model_compile_options);
- std::string model_path = PathToUTF8String(input_model_path);
+ std::filesystem::path model_path = input_model_path;
if (model_path.empty()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: path string is empty");
@@ -113,7 +113,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath,
#if !defined(ORT_MINIMAL_BUILD)
auto model_compile_options = reinterpret_cast(ort_model_compile_options);
- std::string model_path = PathToUTF8String(output_model_path);
+ std::filesystem::path model_path = output_model_path;
if (model_path.empty()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output model path: path is empty");
}
@@ -136,17 +136,18 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInf
#if !defined(ORT_MINIMAL_BUILD)
auto model_compile_options = reinterpret_cast(ort_model_compile_options);
- std::string output_dir = PathToUTF8String(output_directory);
- if (output_dir.empty()) {
+ std::filesystem::path output_directory_path = output_directory;
+ if (output_directory_path.empty()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty");
}
- std::string model_name_str = ToUTF8String(model_name);
- if (model_name_str.empty()) {
+ std::filesystem::path model_name_path = model_name;
+ if (model_name_path.empty()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty");
}
- ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str));
+ ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_directory_path,
+ model_name_path));
return nullptr;
#else
ORT_UNUSED_PARAMETER(ort_model_compile_options);
@@ -163,7 +164,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExterna
size_t external_initializer_size_threshold) {
API_IMPL_BEGIN
#if !defined(ORT_MINIMAL_BUILD)
- std::string initializers_file_path = PathToUTF8String(external_initializers_file_path);
+ std::filesystem::path initializers_file_path = external_initializers_file_path;
if (initializers_file_path.empty()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid external initializer file: path is empty");
}
@@ -214,6 +215,50 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer,
API_IMPL_END
}
+ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc,
+ _In_ OrtModelCompilationOptions* ort_model_compile_options,
+ _In_ OrtWriteBufferFunc write_func, _In_ void* state) {
+ API_IMPL_BEGIN
+#if !defined(ORT_MINIMAL_BUILD)
+ auto model_compile_options = reinterpret_cast(ort_model_compile_options);
+
+ if (write_func == nullptr) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtWriteBufferFunc function for output model is null");
+ }
+
+ model_compile_options->SetOutputModelWriteFunc(write_func, state);
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(ort_model_compile_options);
+ ORT_UNUSED_PARAMETER(write_func);
+ ORT_UNUSED_PARAMETER(state);
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
+#endif // !defined(ORT_MINIMAL_BUILD)
+ API_IMPL_END
+}
+ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
+ _In_ OrtModelCompilationOptions* ort_model_compile_options,
+ _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state) {
+ API_IMPL_BEGIN
+#if !defined(ORT_MINIMAL_BUILD)
+ auto model_compile_options = reinterpret_cast(ort_model_compile_options);
+
+ if (get_initializer_location_func == nullptr) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+ "OrtGetInitializerLocationFunc function for output model is null");
+ }
+
+ model_compile_options->SetOutputModelGetInitializerLocationFunc(get_initializer_location_func, state);
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(ort_model_compile_options);
+ ORT_UNUSED_PARAMETER(get_initializer_location_func);
+ ORT_UNUSED_PARAMETER(state);
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
+#endif // !defined(ORT_MINIMAL_BUILD)
+ API_IMPL_END
+}
+
ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode,
_In_ OrtModelCompilationOptions* ort_model_compile_options,
bool embed_ep_context_in_model) {
@@ -295,6 +340,8 @@ static constexpr OrtCompileApi ort_compile_api = {
&OrtCompileAPI::ModelCompilationOptions_SetFlags,
&OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation,
&OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
+ &OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc,
+ &OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
};
// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h
index 51cf71cd6ec61..34fa06340a7f9 100644
--- a/onnxruntime/core/session/compile_api.h
+++ b/onnxruntime/core/session/compile_api.h
@@ -35,5 +35,11 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel,
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ GraphOptimizationLevel graph_optimization_level);
+ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelWriteFunc,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ OrtWriteBufferFunc write_func, _In_ void* state);
+ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state);
} // namespace OrtCompileAPI
diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc
index 3ada35eeaff63..84f41771cb62b 100644
--- a/onnxruntime/core/session/model_compilation_options.cc
+++ b/onnxruntime/core/session/model_compilation_options.cc
@@ -7,8 +7,11 @@
#include
#include
#include
+#include
+#include "core/common/path_string.h"
#include "core/framework/allocator.h"
+#include "core/framework/ep_context_options.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/environment.h"
@@ -22,7 +25,7 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment&
// defaulting to kGenerateModel to support wider usage.
session_options_.value.ep_context_gen_options.action_if_no_compiled_nodes =
- EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel;
+ epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel;
// Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions.
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK());
@@ -31,7 +34,7 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment&
session_options_.value.graph_optimization_level = TransformerLevel::Default; // L0: required transformers only
}
-void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) {
+void ModelCompilationOptions::SetInputModelPath(const std::filesystem::path& input_model_path) {
ResetInputModelSettings();
input_model_path_ = input_model_path;
}
@@ -42,17 +45,16 @@ void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_da
input_model_data_size_ = input_model_data_size;
}
-Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_model_path) {
- ORT_RETURN_IF_ERROR(ResetOutputModelSettings());
-
+Status ModelCompilationOptions::SetOutputModelPath(const std::filesystem::path& output_model_path) {
ConfigOptions& config_options = session_options_.value.config_options;
- EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
+ epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
+
+ ep_context_gen_options.output_model_location = output_model_path;
- ep_context_gen_options.output_model_file_path = output_model_path;
+ std::string output_model_path_str = PathToUTF8String(output_model_path);
- if (ep_context_gen_options.output_model_file_path.size() <= ConfigOptions::kMaxValueLength) {
- Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath,
- ep_context_gen_options.output_model_file_path.c_str());
+ if (output_model_path_str.size() <= ConfigOptions::kMaxValueLength) {
+ Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_path_str.c_str());
ORT_ENFORCE(status.IsOK()); // Should not fail because both key/value strings are below the min string lengths
// required by ConfigOptions::AddConfigEntry().
} else {
@@ -73,7 +75,7 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod
logging::LoggingManager* log_manager = env_.GetLoggingManager();
if (log_manager != nullptr && log_manager->HasDefaultLogger()) {
const logging::Logger& logger = log_manager->DefaultLogger();
- LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size()
+ LOGS(logger, WARNING) << "Output model path length (" << output_model_path_str.size()
<< ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters."
<< "ORT will still generate the expected output file, but EPs will see an empty "
<< "output model path in SessionOption's ConfigOptions.";
@@ -82,40 +84,58 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod
return Status::OK();
}
-void ModelCompilationOptions::SetOutputModelExternalInitializersFile(const std::string& external_initializers_path,
- size_t external_initializer_size_threshold) {
- session_options_.value.ep_context_gen_options.output_external_initializers_file_path = external_initializers_path;
- session_options_.value.ep_context_gen_options.output_external_initializer_size_threshold =
- external_initializer_size_threshold;
+void ModelCompilationOptions::SetOutputModelExternalInitializersFile(
+ const std::filesystem::path& external_initializers_path,
+ size_t external_initializer_size_threshold) {
+ session_options_.value.ep_context_gen_options.initializers_location = epctx::ExternalInitializerFileInfo{
+ external_initializers_path,
+ external_initializer_size_threshold,
+ };
}
Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator,
void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr) {
- ORT_RETURN_IF_ERROR(ResetOutputModelSettings());
+ session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferHolder{
+ output_model_buffer_ptr,
+ output_model_buffer_size_ptr,
+ std::move(allocator),
+ };
- session_options_.value.ep_context_gen_options.output_model_buffer_ptr = output_model_buffer_ptr;
- session_options_.value.ep_context_gen_options.output_model_buffer_size_ptr = output_model_buffer_size_ptr;
- session_options_.value.ep_context_gen_options.output_model_buffer_allocator = std::move(allocator);
return Status::OK();
}
-Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory,
- const std::string& model_name) {
+void ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state) {
+ session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferWriteFuncHolder{
+ write_func,
+ state,
+ };
+}
+
+void ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc(
+ OrtGetInitializerLocationFunc get_initializer_location_func, void* state) {
+ session_options_.value.ep_context_gen_options.initializers_location = epctx::InitializerHandler{
+ get_initializer_location_func,
+ state,
+ };
+}
+
+Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::filesystem::path& output_directory,
+ const std::filesystem::path& model_name) {
if (output_directory.empty() || model_name.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty.");
}
- std::filesystem::path output_dir_path(output_directory);
- if (output_dir_path.has_filename() && output_dir_path.extension() == "") {
+ if (output_directory.has_filename() && output_directory.extension() == "") {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory.");
}
- std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name);
+ std::filesystem::path ctx_model_path = output_directory / model_name;
+ std::string ctx_model_path_str = PathToUTF8String(ctx_model_path);
- if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) {
+ if (ctx_model_path_str.size() <= ConfigOptions::kMaxValueLength) {
ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath,
- ctx_model_path.string().c_str()));
+ ctx_model_path_str.c_str()));
} else {
logging::LoggingManager* log_manager = env_.GetLoggingManager();
if (log_manager != nullptr && log_manager->HasDefaultLogger()) {
@@ -138,11 +158,11 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m
}
Status ModelCompilationOptions::SetFlags(uint32_t flags) {
- EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options;
+ epctx::ModelGenOptions& options = session_options_.value.ep_context_gen_options;
options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS;
options.action_if_no_compiled_nodes =
- (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError
- : EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel;
+ (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? epctx::ModelGenOptions::ActionIfNoCompiledNodes::kReturnError
+ : epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel;
return Status::OK();
}
@@ -154,7 +174,7 @@ bool ModelCompilationOptions::InputModelComesFromFile() const {
return !input_model_path_.empty();
}
-const std::string& ModelCompilationOptions::GetInputModelPath() const {
+const std::filesystem::path& ModelCompilationOptions::GetInputModelPath() const {
return input_model_path_;
}
@@ -200,77 +220,78 @@ Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel
return Status::OK();
}
-Status ModelCompilationOptions::ResetOutputModelSettings() {
- EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
- ep_context_gen_options.output_model_file_path.clear();
- ep_context_gen_options.output_model_buffer_ptr = nullptr;
- ep_context_gen_options.output_model_buffer_size_ptr = nullptr;
- ep_context_gen_options.output_model_buffer_allocator = nullptr;
- return Status::OK();
-}
+Status ModelCompilationOptions::Check() const {
+ const ConfigOptions& config_options = session_options_.value.config_options;
+
+ ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable);
+ ORT_ENFORCE(config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0");
-Status ModelCompilationOptions::CheckInputModelSettings() const {
- const bool comes_from_file = !input_model_path_.empty();
- const bool comes_from_memory = input_model_data_ != nullptr;
+ // Check input model settings.
+ const bool input_from_file = !input_model_path_.empty();
+ const bool input_from_memory = input_model_data_ != nullptr;
- if (!comes_from_file && !comes_from_memory) {
+ if (!input_from_file && !input_from_memory) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input model to compile must be loaded from either a file or a memory buffer");
}
- if (comes_from_file && comes_from_memory) {
+ if (input_from_file && input_from_memory) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input model to compile must be loaded from either a file or a memory buffer, ",
"but not both.");
}
- if (comes_from_file && !std::filesystem::exists(input_model_path_)) {
+ if (input_from_file && !std::filesystem::exists(input_model_path_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model path does not exist: ", input_model_path_);
}
- if (comes_from_memory && input_model_data_size_ == 0) {
+ if (input_from_memory && input_model_data_size_ == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0");
}
- return Status::OK();
-}
-
-Status ModelCompilationOptions::CheckOutputModelSettings() const {
- const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
+ // Check output model settings.
+ const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
+ bool has_no_output_model_location = std::holds_alternative(
+ ep_context_gen_options.output_model_location);
- const bool explicit_writes_to_file = !ep_context_gen_options.output_model_file_path.empty();
- const bool writes_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr;
-
- if (!explicit_writes_to_file && !writes_to_buffer) {
- // User did not specify an output file or an output buffer. We default to generating an output file
- // with a name based on the input file name, so do not return an error.
+ if (has_no_output_model_location && input_from_file) {
+ // User did not specify an output file, an output buffer, or an output write function. We default to generating an
+ // output file with a name based on the input file name, so do not return an error.
return Status::OK();
}
- if (explicit_writes_to_file && writes_to_buffer) {
+ if (has_no_output_model_location) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Output model to compile must be saved either to a file or to a buffer, but not both.");
+ "Unable to generate an output model path: require an input model path if the location "
+ "of the output model (e.g., file, buffer, or stream) is not specified.");
}
- if (writes_to_buffer && ep_context_gen_options.output_model_buffer_size_ptr == nullptr) {
+ const epctx::BufferHolder* output_buffer_ptr = ep_context_gen_options.TryGetOutputModelBuffer();
+
+ if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_ptr == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Invalid buffer configuration for output model: buffer pointer is null");
+ }
+
+ if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_size_ptr == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid buffer configuration for output model: size pointer is null");
}
- if (writes_to_buffer && ep_context_gen_options.output_model_buffer_allocator == nullptr) {
+ if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_allocator == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid buffer configuration for output model: allocator is null");
}
- return Status::OK();
-}
+ const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc();
+
+ if (output_write_func_holder != nullptr && output_write_func_holder->write_func == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Invalid buffer writing function for output model: function pointer is null");
+ }
-Status ModelCompilationOptions::Check() const {
- ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable);
- ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0");
- ORT_RETURN_IF_ERROR(CheckInputModelSettings());
- ORT_RETURN_IF_ERROR(CheckOutputModelSettings());
return Status::OK();
}
+
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h
index cd9091561af79..45323e6cb13c5 100644
--- a/onnxruntime/core/session/model_compilation_options.h
+++ b/onnxruntime/core/session/model_compilation_options.h
@@ -4,6 +4,7 @@
#if !defined(ORT_MINIMAL_BUILD)
#pragma once
+#include <filesystem>
#include
#include
#include "core/common/status.h"
@@ -34,7 +35,7 @@ class ModelCompilationOptions {
/// Overrides any previous call to SetInputModelPath() or SetInputModelFromBuffer().
///
/// The input model's path
- void SetInputModelPath(const std::string& input_model_path);
+ void SetInputModelPath(const std::filesystem::path& input_model_path);
///
/// Sets the buffer that stores the input ONNX model to compile.
@@ -50,7 +51,7 @@ class ModelCompilationOptions {
///
///
/// Status indicating potential error
- Status SetOutputModelPath(const std::string& output_model_path);
+ Status SetOutputModelPath(const std::filesystem::path& output_model_path);
///
/// Sets the file path to the file that will store external ONNX initializers for the compiled model.
@@ -58,7 +59,7 @@ class ModelCompilationOptions {
///
/// Path to the external initializers file to generate
/// Initializers that exceed this threshold are external
- void SetOutputModelExternalInitializersFile(const std::string& external_initializers_path,
+ void SetOutputModelExternalInitializersFile(const std::filesystem::path& external_initializers_path,
size_t external_initializer_size_threshold);
///
@@ -72,6 +73,21 @@ class ModelCompilationOptions {
Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr);
+ ///
+ /// Sets an output stream (write function + state) used to write out the compiled model bytes.
+ ///
+ /// Write function
+ /// The user's state
+ void SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state);
+
+ ///
+ /// Sets a user-provided function to handle serialization of ONNX initializers.
+ ///
+ /// The user-provided function called for every initializer
+ /// The user's state.
+ void SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func,
+ void* state);
+
///
/// Sets information relate to EP context binary file.
/// EP use this information to decide the location and context binary file name.
@@ -80,7 +96,8 @@ class ModelCompilationOptions {
/// The folder path to the generated context binary file
/// Model name used to decide the context binary file name: [model_name]_[ep].bin
/// Status indicating potential error
- Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name);
+ Status SetEpContextBinaryInformation(const std::filesystem::path& output_directory,
+ const std::filesystem::path& model_name);
///
/// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext
@@ -107,7 +124,7 @@ class ModelCompilationOptions {
/// Returns the file path to the input ONNX model.
///
/// input model's path
- const std::string& GetInputModelPath() const;
+ const std::filesystem::path& GetInputModelPath() const;
///
/// Returns true if the input model is read from a file.
@@ -144,13 +161,10 @@ class ModelCompilationOptions {
private:
void ResetInputModelSettings();
- Status ResetOutputModelSettings();
- Status CheckInputModelSettings() const;
- Status CheckOutputModelSettings() const;
const onnxruntime::Environment& env_;
OrtSessionOptions session_options_;
- std::string input_model_path_;
+ std::filesystem::path input_model_path_;
const void* input_model_data_ = nullptr;
size_t input_model_data_size_ = 0;
};
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index f3e2a8ce7ba7b..36f7f1f60c36e 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2538,6 +2538,23 @@ ORT_API(void, OrtApis::ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExtern
delete static_cast(info);
}
+ORT_API_STATUS_IMPL(OrtApis::CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath,
+ _In_ int64_t file_offset, _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out) {
+ API_IMPL_BEGIN
+#if !defined(ORT_MINIMAL_BUILD)
+ auto ext_data_info = std::make_unique<OrtExternalInitializerInfo>(filepath, file_offset, byte_size);
+ *out = ext_data_info.release();
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(filepath);
+ ORT_UNUSED_PARAMETER(file_offset);
+ ORT_UNUSED_PARAMETER(byte_size);
+ ORT_UNUSED_PARAMETER(out);
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateExternalInitializerInfo() is not supported in this build.");
+#endif
+ API_IMPL_END
+}
+
ORT_API(const ORTCHAR_T*, OrtApis::ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info) {
return info->GetRelPath().c_str();
}
@@ -4202,6 +4219,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::Graph_GetModelMetadata,
&OrtApis::GetModelCompatibilityForEpDevices,
+ &OrtApis::CreateExternalInitializerInfo,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 6dc4cf9d195cc..78616c7b3973e 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -700,6 +700,8 @@ ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_may
// OrtExternalInitializerInfo
ORT_API(void, ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExternalInitializerInfo* info);
+ORT_API_STATUS_IMPL(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset,
+ _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out);
ORT_API(const ORTCHAR_T*, ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info);
ORT_API(int64_t, ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info);
ORT_API(size_t, ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info);
diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc
index 7da7fabb15b15..444027692903c 100644
--- a/onnxruntime/core/session/utils.cc
+++ b/onnxruntime/core/session/utils.cc
@@ -136,13 +136,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op
// If ep.context_enable is set, then ep.context_file_path is expected, otherwise ORT don't know where to generate the _ctx.onnx file
if (options && model_path == nullptr) {
- EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions();
+ epctx::ModelGenOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions();
// This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case
// the user used the older SessionOptions' configuration entries to generate a compiled model.
- if (ep_ctx_gen_options.enable &&
- ep_ctx_gen_options.output_model_file_path.empty() &&
- ep_ctx_gen_options.output_model_buffer_ptr == nullptr) {
+ if (ep_ctx_gen_options.enable && !ep_ctx_gen_options.HasOutputModelLocation()) {
return OrtApis::CreateStatus(ORT_FAIL,
"Inference session was configured with EPContext model generation enabled but "
"without a valid location (e.g., file or buffer) for the output model. "
@@ -383,7 +381,7 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model
const OrtSessionOptions* session_options = &model_compile_options.GetSessionOptions();
if (model_compile_options.InputModelComesFromFile()) {
- PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath());
+ const std::filesystem::path& input_model_path = model_compile_options.GetInputModelPath();
ORT_RETURN_IF_ERROR(ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env,
input_model_path.c_str(),
nullptr, 0, session)));
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 8c8ba214eb714..35abad5760c32 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -9,7 +9,7 @@
import os
import typing
import warnings
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from typing import Any
from onnxruntime.capi import _pybind_state as C
@@ -620,6 +620,36 @@ def _register_ep_custom_ops(self, session_options, providers, provider_options,
C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1])
+def make_get_initializer_location_func_wrapper(
+ get_initializer_location_func: GetInitializerLocationFunc,
+) -> GetInitializerLocationWrapperFunc:
+ """
+ Wraps a user's "get initializer location" function. The returned wrapper function adheres to the
+ signature expected by ORT.
+
+ Need this wrapper to:
+ - Convert the `initializer_value` parameter from `C.OrtValue` to `onnxruntime.OrtValue`, which is more
+ convenient for the user's function to use.
+ - Allow the user's function to return the original `external_info` parameter (this wrapper makes a copy)
+ """
+
+ def get_initializer_location_func_wrapper(
+ initializer_name: str,
+ initializer_value: C.OrtValue,
+ external_info: C.OrtExternalInitializerInfo | None,
+ ) -> C.OrtExternalInitializerInfo | None:
+ ret_val: C.OrtExternalInitializerInfo | None = get_initializer_location_func(
+ initializer_name, OrtValue(initializer_value), external_info
+ )
+ if ret_val is not None and ret_val == external_info:
+ # User returned `external_info` (const and owned by ORT). ORT expects the returned value to be
+ # a new instance (that it deletes), so make a copy.
+ ret_val = C.OrtExternalInitializerInfo(ret_val.filepath, ret_val.file_offset, ret_val.byte_size)
+ return ret_val
+
+ return get_initializer_location_func_wrapper
+
+
class ModelCompiler:
"""
This class is used to compile an ONNX model. A compiled ONNX model has EPContext nodes that each
@@ -648,6 +678,7 @@ def __init__(
external_initializers_size_threshold: int = 1024,
flags: int = C.OrtCompileApiFlags.NONE,
graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL,
+ get_initializer_location_func: GetInitializerLocationFunc | None = None,
):
"""
Creates a ModelCompiler instance.
@@ -666,6 +697,25 @@ def __init__(
flags in onnxruntime.OrtCompileApiFlags.
:param graph_optimization_level: The graph optimization level.
Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL.
+ :param get_initializer_location_func: Optional function called for every initializer to allow the user to specify
+ whether an initializer should be stored within the model or externally. Example:
+ ```
+ def get_initializer_location(
+ initializer_name: str,
+ initializer_value: onnxrt.OrtValue,
+ external_info: onnxrt.OrtExternalInitializerInfo | None,
+ ) -> onnxrt.OrtExternalInitializerInfo | None:
+ byte_size = initializer_value.tensor_size_in_bytes()
+
+ if byte_size < 64:
+ return None # Store small initializer within compiled model.
+
+ # Else, write initializer to new external file.
+ value_np = initializer_value.numpy()
+ file_offset = ext_init_file.tell()
+ ext_init_file.write(value_np.tobytes())
+ return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size)
+ ```
"""
input_model_path: str | os.PathLike | None = None
input_model_bytes: bytes | None = None
@@ -688,6 +738,18 @@ def __init__(
else:
external_initializers_file_path = ""
+ if get_initializer_location_func is not None:
+ if external_initializers_file_path:
+ raise ValueError(
+ "Cannot initialize ModelCompiler with both `external_initializers_file_path` "
+ "and `get_initializer_location_func`"
+ )
+ self.get_initializer_location_func_wrapper = make_get_initializer_location_func_wrapper(
+ get_initializer_location_func
+ )
+ else:
+ self.get_initializer_location_func_wrapper = None
+
if input_model_path:
self._model_compiler = C.ModelCompiler(
sess_options,
@@ -698,6 +760,7 @@ def __init__(
external_initializers_size_threshold,
flags,
graph_optimization_level,
+ self.get_initializer_location_func_wrapper,
)
else:
self._model_compiler = C.ModelCompiler(
@@ -709,6 +772,7 @@ def __init__(
external_initializers_size_threshold,
flags,
graph_optimization_level,
+ self.get_initializer_location_func_wrapper,
)
def compile_to_file(self, output_model_path: str | None = None):
@@ -738,6 +802,14 @@ def compile_to_bytes(self) -> bytes:
"""
return self._model_compiler.compile_to_bytes()
+ def compile_to_stream(self, write_function: Callable[[bytes], None]):
+ """
+ Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function.
+ Raises an 'InvalidArgument' exception if the compilation options are invalid.
+ :param write_function: A callable that accepts a bytes buffer to write.
+ """
+ self._model_compiler.compile_to_stream(write_function)
+
class IOBinding:
"""
@@ -1298,3 +1370,14 @@ def device_name(self) -> str:
Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda
"""
return self._tensor.device_name().lower()
+
+
+# Type hint for user-specified function that allows the user to specify initializer locations when compiling a model.
+GetInitializerLocationFunc = Callable[
+ [str, OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None
+]
+
+# Type hint that adheres to the signature expected by ORT.
+GetInitializerLocationWrapperFunc = Callable[
+ [str, C.OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None
+]
diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc
index 69929cb68a775..6ff252b5d1353 100644
--- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc
+++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
// Licensed under the MIT License.
+#if !defined(ORT_MINIMAL_BUILD)
#include "python/onnxruntime_pybind_model_compiler.h"
#include
@@ -8,11 +9,56 @@
#include
#include "core/common/common.h"
#include "core/framework/error_code_helper.h"
+#include "core/graph/abi_graph_types.h"
#include "core/session/utils.h"
namespace onnxruntime {
namespace python {
+///
+/// This function is called by ORT to allow the user to handle where every initializer is stored
+/// (i.e., externally or internally). This function wraps (and calls) the actual Python function
+/// provided by the user.
+///
+/// Opaque state that holds a pointer to the user's Python function.
+/// The name of the initializer to handle.
+/// The OrtValue with the initializer's data, type, and shape.
+/// The original external location of the initializer, if any. May be null.
+/// Output parameter set to the initializer's new external location. Function may
+/// return NULL if the initializer should be stored within the compiled ONNX model.
+/// A status indicating success or an error.
+static OrtStatus* ORT_API_CALL PyGetInitializerLocationFuncWrapper(
+ void* state,
+ const char* initializer_name,
+ const OrtValue* initializer_value,
+ const OrtExternalInitializerInfo* external_info,
+ /*out*/ OrtExternalInitializerInfo** new_external_info) {
+ PyGetInitializerLocationFunc* py_func = reinterpret_cast<PyGetInitializerLocationFunc*>(state);
+ OrtStatus* status = nullptr;
+ std::shared_ptr<OrtExternalInitializerInfo> py_new_external_info = nullptr;
+
+ // Call the Python function and convert any exceptions to a status.
+ ORT_TRY {
+ py_new_external_info = (*py_func)(initializer_name, *initializer_value, external_info);
+ }
+ ORT_CATCH(const std::exception& e) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()));
+ });
+ }
+
+ if (py_new_external_info) {
+ // ORT expects to take ownership of the new external info, so make a copy because other Python code
+ // may be holding a reference to the `py_new_external_info`.
+ auto py_result_copy = std::make_unique<OrtExternalInitializerInfo>(*py_new_external_info.get());
+ *new_external_info = py_result_copy.release();
+ } else {
+ *new_external_info = nullptr;
+ }
+
+ return status;
+}
+
onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr& out,
onnxruntime::Environment& env,
const PySessionOptions& sess_options,
@@ -21,8 +67,10 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr(env, sess_options, PrivateConstructorTag{});
+ GraphOptimizationLevel graph_optimization_level,
+ const PyGetInitializerLocationFunc& py_get_initializer_location_func) {
+ auto model_compiler = std::make_unique(env, sess_options, py_get_initializer_location_func,
+ PrivateConstructorTag{});
ModelCompilationOptions& compile_options = model_compiler->model_compile_options_;
if (input_model_is_path) {
@@ -46,6 +94,12 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptrpy_get_initializer_location_func_) {
+ compile_options.SetOutputModelGetInitializerLocationFunc(
+ PyGetInitializerLocationFuncWrapper,
+ reinterpret_cast<void*>(&model_compiler->py_get_initializer_location_func_));
+ }
+
out = std::move(model_compiler);
return Status::OK();
}
@@ -80,9 +134,47 @@ onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer)
return Status::OK();
}
+///
+/// Function called by ORT to allow the user to write out the compiled ONNX model bytes to a custom output stream.
+/// This function wraps (and calls) the actual Python function provided by the user.
+///
+/// Opaque state that holds a pointer to the user's Python function.
+/// The buffer to write out. Contains a portion of the compiled ONNX model's bytes.
+/// The number of bytes in the buffer.
+/// A status indicating success or an error.
+static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, const void* buffer,
+ size_t buffer_num_bytes) {
+ PyOutStreamWriteFunc* py_write_func = reinterpret_cast<PyOutStreamWriteFunc*>(stream_state);
+ OrtStatus* status = nullptr;
+
+ // Call the Python write function and convert any exceptions to a status.
+ ORT_TRY {
+ pybind11::bytes py_bytes(reinterpret_cast<const char*>(buffer), buffer_num_bytes);
+ (*py_write_func)(py_bytes);
+ }
+ ORT_CATCH(const std::exception& e) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()));
+ });
+ }
+
+ return status;
+}
+
+onnxruntime::Status PyModelCompiler::CompileToOutStream(PyOutStreamWriteFunc& write_func) {
+ model_compile_options_.SetOutputModelWriteFunc(PyOutStreamWriteFuncWrapper,
+ reinterpret_cast<void*>(&write_func));
+ ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(env_, model_compile_options_));
+ return Status::OK();
+}
+
PyModelCompiler::PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options,
+ const PyGetInitializerLocationFunc& py_get_initializer_location_func,
PrivateConstructorTag)
- : env_(env), model_compile_options_(env, sess_options) {
+ : env_(env),
+ model_compile_options_(env, sess_options),
+ py_get_initializer_location_func_(py_get_initializer_location_func) {
}
} // namespace python
} // namespace onnxruntime
+#endif // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h
index d770dbf65cc10..957350accdba2 100644
--- a/onnxruntime/python/onnxruntime_pybind_model_compiler.h
+++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h
@@ -3,7 +3,6 @@
// Licensed under the MIT License.
#pragma once
-#if !defined(ORT_MINIMAL_BUILD)
#include
#include
#include "core/common/status.h"
@@ -14,11 +13,24 @@ namespace onnxruntime {
class Environment;
namespace python {
+// Type of the function provided by Python code that is called by ORT to write out the compiled model.
+using PyOutStreamWriteFunc = std::function<void(pybind11::bytes)>;
+
+// Type of the function provided by Python code that is called by ORT to handle every initializer.
+using PyGetInitializerLocationFunc = std::function<std::shared_ptr<OrtExternalInitializerInfo>(
+ const std::string& initializer_name,
+ const OrtValue& initializer_value,
+ const OrtExternalInitializerInfo* external_info)>;
+
///
/// Class exposed to Python that enables compiling ONNX models.
/// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings.
///
class PyModelCompiler {
+#if defined(ORT_MINIMAL_BUILD)
+ public:
+ bool not_defined_in_this_build{}; // Prevent empty class warning.
+#else
private:
// private tag to pass to constructor to ensure that constructor cannot be directly called externally
struct PrivateConstructorTag {};
@@ -40,6 +52,7 @@ class PyModelCompiler {
/// Flags from OrtCompileApiFlags
/// Optimization level for graph transformations on the model.
/// Defaults to ORT_DISABLE_ALL to allow EP to get the original loaded model.
+ /// User's function to handle saving of initializers.
/// A Status indicating error or success.
static onnxruntime::Status Create(/*out*/ std::unique_ptr& out,
onnxruntime::Environment& env,
@@ -49,11 +62,13 @@ class PyModelCompiler {
const std::string& external_initializers_file_path = {},
size_t external_initializers_size_threshold = 1024,
uint32_t flags = 0,
- GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL);
+ GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL,
+ const PyGetInitializerLocationFunc& py_get_initializer_location_func = nullptr);
// Note: Creation should be done via Create(). This constructor is public so that it can be called from
// std::make_shared().
PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options,
+ const PyGetInitializerLocationFunc& py_get_initializer_location_func,
PrivateConstructorTag);
///
@@ -73,11 +88,19 @@ class PyModelCompiler {
/// A Status indicating error or success.
onnxruntime::Status CompileToBytes(std::string& output_buffer);
+ ///
+ /// Compiles the input model and writes the result into the provided output stream (write functor).
+ ///
+ /// Write functor that encapsulates the stream's state.
+ /// A Status indicating error or success.
+ onnxruntime::Status CompileToOutStream(PyOutStreamWriteFunc& write_func);
+
private:
onnxruntime::Environment& env_;
onnxruntime::ModelCompilationOptions model_compile_options_;
std::string input_model_bytes_;
+ PyGetInitializerLocationFunc py_get_initializer_location_func_;
+#endif // defined(ORT_MINIMAL_BUILD)
};
} // namespace python
} // namespace onnxruntime
-#endif // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 27c76f7f5c482..e370518b1fffb 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -6,11 +6,8 @@
#include
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_mlvalue.h"
-#include "python/onnxruntime_pybind_state_common.h"
-
-#if !defined(ORT_MINIMAL_BUILD)
#include "python/onnxruntime_pybind_model_compiler.h"
-#endif // !defined(ORT_MINIMAL_BUILD)
+#include "python/onnxruntime_pybind_state_common.h"
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API
@@ -45,6 +42,7 @@
#include "core/session/lora_adapters.h"
#if !defined(ORT_MINIMAL_BUILD)
+#include "core/graph/abi_graph_types.h"
#include "core/session/abi_devices.h"
#include "core/session/plugin_ep/ep_factory_internal.h"
#include "core/session/provider_policy_context.h"
@@ -2730,6 +2728,35 @@ including arg name, arg type (contains both type and shape).)pbdoc")
.value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested)
.export_values();
+ // Must use a std::shared_ptr to hold OrtExternalInitializerInfo because the same instance is passed
+ // between C++ and Python (and Python cannot transfer ownership to C++).
+ py::class_<OrtExternalInitializerInfo, std::shared_ptr<OrtExternalInitializerInfo>> ort_external_initializer_info_binding(
+ m, "OrtExternalInitializerInfo",
+ R"pbdoc(Location information for initializer data stored in an external file)pbdoc");
+ ort_external_initializer_info_binding
+ .def(py::init([](const std::basic_string<ORTCHAR_T>& filepath, int64_t file_offset, size_t byte_size) {
+#if !defined(ORT_MINIMAL_BUILD)
+ return std::make_shared<OrtExternalInitializerInfo>(filepath, file_offset, byte_size);
+#else
+ ORT_UNUSED_PARAMETER(filepath);
+ ORT_UNUSED_PARAMETER(file_offset);
+ ORT_UNUSED_PARAMETER(byte_size);
+ ORT_THROW("OrtExternalInitializerInfo creation is not supported in this build");
+#endif
+ }))
+ .def_property_readonly(
+ "filepath",
+ [](OrtExternalInitializerInfo* info) -> std::basic_string<ORTCHAR_T> { return info->GetRelPath(); },
+ R"pbdoc(The relative path to the file in which initializer data is stored.)pbdoc")
+ .def_property_readonly(
+ "file_offset",
+ [](OrtExternalInitializerInfo* info) -> int64_t { return info->GetOffset(); },
+ R"pbdoc(The file byte offset where the initializer data is stored.)pbdoc")
+ .def_property_readonly(
+ "byte_size",
+ [](OrtExternalInitializerInfo* info) -> size_t { return info->GetLength(); },
+ R"pbdoc(The byte size of the initializer data in the file.)pbdoc");
+
py::enum_(m, "OrtCompileApiFlags", py::arithmetic())
.value("NONE", OrtCompileApiFlags_NONE)
.value("ERROR_IF_NO_NODES_COMPILED", OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED)
@@ -2744,7 +2771,8 @@ including arg name, arg type (contains both type and shape).)pbdoc")
std::string external_initializers_file_path = {},
size_t external_initializers_size_threshold = 1024,
uint32_t flags = OrtCompileApiFlags_NONE,
- GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_DISABLE_ALL) {
+ GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_DISABLE_ALL,
+ const PyGetInitializerLocationFunc& py_get_initializer_location_func = nullptr) {
#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr result;
OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options,
@@ -2752,7 +2780,8 @@ including arg name, arg type (contains both type and shape).)pbdoc")
embed_compiled_data_into_model,
external_initializers_file_path,
external_initializers_size_threshold,
- flags, graph_optimization_level));
+ flags, graph_optimization_level,
+ py_get_initializer_location_func));
return result;
#else
ORT_UNUSED_PARAMETER(sess_options);
@@ -2763,6 +2792,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
ORT_UNUSED_PARAMETER(external_initializers_size_threshold);
ORT_UNUSED_PARAMETER(flags);
ORT_UNUSED_PARAMETER(graph_optimization_level);
+ ORT_UNUSED_PARAMETER(py_get_initializer_location_func);
ORT_THROW("Compile API is not supported in this build.");
#endif
}))
@@ -2790,7 +2820,19 @@ including arg name, arg type (contains both type and shape).)pbdoc")
ORT_THROW("Compile API is not supported in this build.");
#endif
},
- R"pbdoc(Compile an ONNX model into a buffer.)pbdoc");
+ R"pbdoc(Compile an ONNX model into a buffer.)pbdoc")
+ .def(
+ "compile_to_stream",
+ [](PyModelCompiler* model_compiler, PyOutStreamWriteFunc& py_stream_write_func) {
+#if !defined(ORT_MINIMAL_BUILD)
+ OrtPybindThrowIfError(model_compiler->CompileToOutStream(py_stream_write_func));
+#else
+ ORT_UNUSED_PARAMETER(model_compiler);
+ ORT_UNUSED_PARAMETER(py_stream_write_func);
+ ORT_THROW("Compile API is not supported in this build.");
+#endif
+ },
+ R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc");
}
bool InitArray() {
diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc
index 6ad21fa9f5cf5..a9d6273ae2f20 100644
--- a/onnxruntime/test/framework/session_state_test.cc
+++ b/onnxruntime/test/framework/session_state_test.cc
@@ -11,6 +11,7 @@
#include "core/framework/kernel_registry.h"
#include "core/framework/op_kernel.h"
#include "core/framework/bfc_arena.h"
+#include "core/framework/ep_context_options.h"
#include "core/framework/session_state.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
@@ -504,7 +505,7 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path,
ASSERT_STATUS_OK(
partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn,
sess_options.config_options, default_logger, GraphPartitioner::Mode::kNormal,
- EpContextModelGenerationOptions{},
+ epctx::ModelGenOptions{},
debug_graph_fn));
verifier_fn(graph);
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index a42a56492b04a..1c8cc6f78fe63 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -2077,6 +2077,278 @@ TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) {
EXPECT_STREQ("Unsupported EP Dynamic Option", e.what());
}
}
+
+// Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file.
+static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) {
+ std::ofstream* outfile = reinterpret_cast<std::ofstream*>(stream_state);
+ outfile->write(reinterpret_cast<const char*>(buffer), buffer_num_bytes);
+ return nullptr; // No error
+}
+
+// Implementation of OrtOutStreamWriteFunc that directly returns an OrtStatus indicating an error.
+static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) {
+ ORT_UNUSED_PARAMETER(stream_state);
+ ORT_UNUSED_PARAMETER(buffer);
+ ORT_UNUSED_PARAMETER(buffer_num_bytes);
+ return Ort::GetApi().CreateStatus(ORT_FAIL, "Error from OrtOutStreamWriteFunc callback");
+}
+
+// Test using the CompileModel() API with settings:
+// - input model comes from a file
+// - write output model to custom write stream
+TEST_F(QnnHTPBackendTests, CompileApi_InputFile_WriteOutputModelBytes) {
+ const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_inputfile_writeoutputmodelbytes.onnx");
+ std::filesystem::remove(input_model_file);
+
+ // Create a test model and save it to a file.
+ TestModel test_model;
+ CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model);
+ ASSERT_STATUS_OK(test_model.Save(input_model_file));
+
+ // Initialize session options with QNN EP
+ Ort::SessionOptions so;
+ ProviderOptions provider_options;
+ provider_options["backend_type"] = "htp";
+ provider_options["offload_graph_io_quantization"] = "0";
+ so.AppendExecutionProvider("QNN", provider_options);
+
+ const ORTCHAR_T* output_model_file = ORT_TSTR("compileapi_inputfile_writeoutputmodelbytes_ctx.onnx");
+ std::filesystem::remove(output_model_file);
+
+ // Open an output file. Test will incrementally write the output model to file
+ // via calls to our OrtOutStreamWriteFunc callback.
+ ASSERT_FALSE(std::filesystem::exists(output_model_file));
+ std::ofstream outfile(output_model_file, std::ios::binary);
+
+ // Create model compilation options from the session options.
+ Ort::ModelCompilationOptions compile_options(*ort_env, so);
+ compile_options.SetInputModelPath(input_model_file);
+ compile_options.SetOutputModelWriteFunc(TestWriteToStream, reinterpret_cast<void*>(&outfile));
+ compile_options.SetEpContextEmbedMode(true);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
+
+ // Compile the model.
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+ ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+ outfile.flush();
+ outfile.close();
+
+ // Check that the compiled model has the expected number of EPContext nodes.
+ ASSERT_TRUE(std::filesystem::exists(output_model_file));
+ CheckEpContextNodeCounts(output_model_file, 2, 2);
+}
+
+// Tests using an OrtOutStreamFunc function that returns an error.
+TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) {
+ // Create a test model (in memory).
+ TestModel test_model;
+ CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model);
+ std::string model_data = test_model.Serialize();
+
+ // Initialize session options with QNN EP
+ Ort::SessionOptions so;
+ ProviderOptions provider_options;
+ provider_options["backend_type"] = "htp";
+ provider_options["offload_graph_io_quantization"] = "0";
+ so.AppendExecutionProvider("QNN", provider_options);
+
+ // Create model compilation options from the session options.
+ Ort::ModelCompilationOptions compile_options(*ort_env, so);
+ compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
+ compile_options.SetOutputModelWriteFunc(ReturnStatusFromStream, nullptr); // Set output stream that returns error
+ compile_options.SetEpContextEmbedMode(true);
+
+ // Compile the model. Expect a specific error status.
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+ ASSERT_FALSE(status.IsOK());
+ EXPECT_EQ(status.GetErrorCode(), ORT_FAIL);
+ EXPECT_EQ(status.GetErrorMessage(), "Error from OrtOutStreamWriteFunc callback");
+}
+
+struct CustomInitializerHandlerState {
+ const ORTCHAR_T* external_file_path = nullptr;
+ std::ofstream* outfile = nullptr;
+};
+
+static OrtStatus* ORT_API_CALL TestHandleInitializerDataFunc(void* state,
+ const char* initializer_name,
+ const OrtValue* c_initializer_value,
+ const OrtExternalInitializerInfo* /*c_external_info*/,
+ OrtExternalInitializerInfo** c_new_external_info) {
+ Ort::Status final_status{nullptr};
+
+ ORT_TRY {
+ CustomInitializerHandlerState* custom_state = reinterpret_cast<CustomInitializerHandlerState*>(state);
+
+ if (std::string("constant") == initializer_name) {
+ // Keep a specific initializer in the model just to test both scenarios.
+ // A real implementation may check the byte size and keep small initializers in the model.
+ *c_new_external_info = nullptr;
+ return nullptr;
+ }
+
+ //
+ // Store other initializers in an external file.
+ //
+ Ort::ConstValue value{c_initializer_value};
+ size_t byte_size = value.GetTensorSizeInBytes();
+ int64_t offset = custom_state->outfile->tellp();
+ const ORTCHAR_T* location = custom_state->external_file_path;
+
+ custom_state->outfile->write(static_cast<const char*>(value.GetTensorRawData()), byte_size);
+ custom_state->outfile->flush();
+
+ // Provide caller (ORT) with the new external info.
+ Ort::ExternalInitializerInfo new_external_info{nullptr};
+ if (Ort::Status status = Ort::ExternalInitializerInfo::Create(location, offset, byte_size, new_external_info);
+ !status.IsOK()) {
+ return status.release();
+ }
+
+ *c_new_external_info = new_external_info.release();
+ }
+ ORT_CATCH(const Ort::Exception& ex) {
+ ORT_HANDLE_EXCEPTION(([&ex, &final_status]() {
+ final_status = Ort::Status{ex};
+ }));
+ }
+ ORT_CATCH(const std::exception& ex) {
+ ORT_HANDLE_EXCEPTION(([&ex, &final_status]() {
+ final_status = Ort::Status(ex.what(), ORT_FAIL);
+ }));
+ }
+
+ return final_status.release();
+}
+
+// Test using the CompileModel() API with settings:
+// - input model comes from a file
+// - write output model to a file
+// - Use callback to specify where each initializer is stored (i.e., external file or within model).
+TEST_F(QnnHTPBackendTests, CompileApi_InputFile_OutputFile_InitializerHandler) {
+ const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler.onnx");
+ const ORTCHAR_T* output_model_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler_ctx.onnx");
+ const ORTCHAR_T* initializer_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler.bin");
+ std::filesystem::remove(input_model_file);
+ std::filesystem::remove(output_model_file);
+ std::filesystem::remove(initializer_file);
+
+ // Create a test model and save it to a file.
+ TestModel test_model;
+ CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model);
+ ASSERT_STATUS_OK(test_model.Save(input_model_file));
+
+ // Initialize session options with QNN EP
+ Ort::SessionOptions so;
+ ProviderOptions provider_options;
+ provider_options["backend_type"] = "htp";
+ provider_options["offload_graph_io_quantization"] = "0";
+ so.AppendExecutionProvider("QNN", provider_options);
+
+ // Open a file to store external initializers. ORT will call our handler function for every initializer.
+ ASSERT_FALSE(std::filesystem::exists(initializer_file));
+ std::ofstream outfile(initializer_file, std::ios::binary);
+ CustomInitializerHandlerState custom_state = {initializer_file, &outfile};
+
+ // Create model compilation options from the session options.
+ Ort::ModelCompilationOptions compile_options(*ort_env, so);
+ compile_options.SetInputModelPath(input_model_file);
+ compile_options.SetOutputModelPath(output_model_file);
+ compile_options.SetOutputModelGetInitializerLocationFunc(TestHandleInitializerDataFunc,
+ reinterpret_cast<void*>(&custom_state));
+ compile_options.SetEpContextEmbedMode(true);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
+
+ // Compile the model.
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+ ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+ outfile.flush();
+ outfile.close();
+
+ ASSERT_TRUE(std::filesystem::exists(initializer_file));
+ ASSERT_TRUE(std::filesystem::exists(output_model_file));
+ CheckEpContextNodeCounts(output_model_file, 2, 2);
+}
+
+static OrtStatus* ORT_API_CALL ReuseExternalInitializers(void* state,
+ const char* /*initializer_name*/,
+ const OrtValue* /*initializer_value*/,
+ const OrtExternalInitializerInfo* external_info,
+ OrtExternalInitializerInfo** new_external_info) {
+ Ort::Status final_status{nullptr};
+
+ ORT_TRY {
+ // If the original initializer was stored in an external file, keep it there (just for testing).
+ if (external_info != nullptr) {
+ Ort::ConstExternalInitializerInfo info(external_info);
+ auto location = info.GetFilePath();
+ int64_t offset = info.GetFileOffset();
+ size_t byte_size = info.GetByteSize();
+
+ Ort::ExternalInitializerInfo new_info(nullptr);
+ Ort::Status status = Ort::ExternalInitializerInfo::Create(location.c_str(), offset, byte_size, new_info);
+ if (!status.IsOK()) {
+ return status.release();
+ }
+
+ *new_external_info = new_info.release();
+
+ // Keep track of number of reused external initializers so that we can assert
+ // that we reused the expected number of initializers.
+ // THIS IS TEST CODE. An application would not do this.
+ size_t* num_reused_ext_initializers = reinterpret_cast<size_t*>(state);
+ *num_reused_ext_initializers += 1;
+
+ return nullptr;
+ }
+
+ // If not originally external, save it within the generated compiled model
+ *new_external_info = nullptr;
+ }
+ ORT_CATCH(const Ort::Exception& ex) {
+ ORT_HANDLE_EXCEPTION(([&ex, &final_status]() {
+ final_status = Ort::Status{ex};
+ }));
+ }
+ ORT_CATCH(const std::exception& ex) {
+ ORT_HANDLE_EXCEPTION(([&ex, &final_status]() {
+ final_status = Ort::Status(ex.what(), ORT_FAIL);
+ }));
+ }
+
+ return final_status.release();
+}
+
+// Test using the CompileModel() API with settings:
+// - input model comes from a file
+// - write output model to a file
+// - Use callback to specify where each initializer is stored. We'll reuse external initializers
+// from original model!
+TEST_F(QnnHTPBackendTests, CompileApi_InitializerHandler_ReuseExternalInitializers) {
+ const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/conv_qdq_external_ini.onnx");
+ const ORTCHAR_T* output_model_file = ORT_TSTR("testdata/conv_qdq_external_ini_reuse_ctx.onnx");
+ std::filesystem::remove(output_model_file);
+
+ size_t num_reused_ext_initializers = 0;
+
+ // Create model compilation options from the session options.
+ Ort::SessionOptions so;
+ Ort::ModelCompilationOptions compile_options(*ort_env, so);
+ compile_options.SetInputModelPath(input_model_file);
+ compile_options.SetOutputModelPath(output_model_file);
+ compile_options.SetOutputModelGetInitializerLocationFunc(ReuseExternalInitializers,
+ reinterpret_cast<void*>(&num_reused_ext_initializers));
+ compile_options.SetEpContextEmbedMode(true);
+
+ // Compile the model.
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+ ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+ ASSERT_TRUE(std::filesystem::exists(output_model_file));
+ std::filesystem::remove(output_model_file);
+
+ ASSERT_EQ(num_reused_ext_initializers, 2); // Reused external conv weight and bias.
+}
+
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
} // namespace test
diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
index ed3cd882d7e00..e46cdb4f98850 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
@@ -275,6 +275,149 @@ def compile_and_get_op_counts(
)
self.assertEqual(op_counts_1["Cast"], 8)
+ def test_compile_from_file_to_stream(self):
+ """
+ Tests compiling a model (from files) to an output stream using a custom write functor.
+ """
+ provider = None
+ provider_options = dict()
+ if "QNNExecutionProvider" in available_providers:
+ provider = "QNNExecutionProvider"
+ provider_options["backend_type"] = "htp"
+
+ input_model_path = get_name("nhwc_resize_scales_opset18.onnx")
+ output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.stream.onnx")
+
+ with open(output_model_path, "wb") as output_fd:
+ # User's custom write functor. Writes the model to a file.
+ def my_write_func(buffer: bytes):
+ self.assertGreater(len(buffer), 0)
+ output_fd.write(buffer)
+
+ session_options = onnxrt.SessionOptions()
+ if provider:
+ session_options.add_provider(provider, provider_options)
+
+ model_compiler = onnxrt.ModelCompiler(
+ session_options,
+ input_model_path,
+ embed_compiled_data_into_model=True,
+ external_initializers_file_path=None,
+ )
+ model_compiler.compile_to_stream(my_write_func)
+
+ self.assertTrue(os.path.exists(output_model_path))
+
+ def test_compile_to_stream_that_raises_exception(self):
+ """
+ Tests compiling a model to an output stream that always raises an exception.
+ """
+ input_model_path = get_name("nhwc_resize_scales_opset18.onnx")
+
+ # User's custom write functor that raises an exception.
+ test_py_error_message = "My Python Error"
+
+ def my_write_func(buffer: bytes):
+ self.assertGreater(len(buffer), 0)
+ raise ValueError(test_py_error_message)
+
+ session_options = onnxrt.SessionOptions()
+ model_compiler = onnxrt.ModelCompiler(
+ session_options,
+ input_model_path,
+ embed_compiled_data_into_model=True,
+ external_initializers_file_path=None,
+ )
+
+ # Try to compile and expect ORT to raise a Fail exception that contains our message.
+ with self.assertRaises(Fail) as context:
+ model_compiler.compile_to_stream(my_write_func)
+ self.assertIn(test_py_error_message, str(context.exception))
+
+ def test_compile_with_basic_initializer_location_func(self):
+ """
+ Tests compiling a model using a custom initializer handler that stores initializers
+ in an external file.
+ """
+ input_model_path = get_name("conv_qdq_external_ini.onnx")
+ output_model_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler.onnx")
+ initializer_file_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler.bin")
+
+ if os.path.exists(output_model_path):
+ os.remove(output_model_path)
+
+ if os.path.exists(initializer_file_path):
+ os.remove(initializer_file_path)
+
+ with open(initializer_file_path, "wb") as ext_init_file:
+
+ def store_large_initializer_externally(
+ initializer_name: str,
+ initializer_value: onnxrt.OrtValue,
+ external_info: onnxrt.OrtExternalInitializerInfo | None,
+ ) -> onnxrt.OrtExternalInitializerInfo | None:
+ self.assertTrue(initializer_name) # Should have valid name
+ byte_size = initializer_value.tensor_size_in_bytes()
+
+ if byte_size < 64:
+ return None # Store small initializer within compiled model.
+
+ # Else, write initializer to new external file.
+ value_np = initializer_value.numpy()
+ file_offset = ext_init_file.tell()
+ ext_init_file.write(value_np.tobytes())
+ return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size)
+
+ session_options = onnxrt.SessionOptions()
+ model_compiler = onnxrt.ModelCompiler(
+ session_options,
+ input_model_path,
+ embed_compiled_data_into_model=True,
+ external_initializers_file_path=None,
+ get_initializer_location_func=store_large_initializer_externally,
+ )
+ model_compiler.compile_to_file(output_model_path)
+
+ self.assertTrue(os.path.exists(output_model_path))
+ self.assertTrue(os.path.exists(initializer_file_path))
+
+ def test_compile_with_initializer_func_that_reuses(self):
+ """
+ Tests compiling a model using a custom initializer handler that reuses external initializer files.
+ """
+ input_model_path = get_name("conv_qdq_external_ini.onnx")
+ output_model_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler_reuse.onnx")
+
+ if os.path.exists(output_model_path):
+ os.remove(output_model_path)
+
+ # Function that reuses external initializer files for the compiled model.
+ def reuse_external_initializers(
+ initializer_name: str,
+ initializer_value: onnxrt.OrtValue,
+ external_info: onnxrt.OrtExternalInitializerInfo | None,
+ ) -> onnxrt.OrtExternalInitializerInfo | None:
+ self.assertTrue(initializer_name) # Should have valid name
+ self.assertNotEqual(initializer_value.data_ptr(), 0)
+ self.assertGreater(initializer_value.tensor_size_in_bytes(), 0)
+ if external_info is not None:
+ # Original initializer is stored externally.
+ # Make the initializer in the compiled model use the same external file
+ return external_info
+
+ return None # Otherwise, make a copy of the initializer and store it within compiled model.
+
+ session_options = onnxrt.SessionOptions()
+ model_compiler = onnxrt.ModelCompiler(
+ session_options,
+ input_model_path,
+ embed_compiled_data_into_model=True,
+ external_initializers_file_path=None,
+ get_initializer_location_func=reuse_external_initializers,
+ )
+ model_compiler.compile_to_file(output_model_path)
+ self.assertTrue(os.path.exists(output_model_path))
+
def test_fail_load_uncompiled_model_and_then_compile(self):
"""
Tests compiling scenario: