diff --git a/VERSION_NUMBER b/VERSION_NUMBER
index a6c2798a482eb..49e0a31d4964d 100644
--- a/VERSION_NUMBER
+++ b/VERSION_NUMBER
@@ -1 +1 @@
-1.23.0
+1.23.1
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 6aad71e40b2a8..b23365e99c2d7 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1800,6 +1800,7 @@ endif()
 if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
     NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
     NOT onnxruntime_MINIMAL_BUILD)
+  # example_plugin_ep
   file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
        "${TEST_SRC_DIR}/autoep/library/*.cc")
   onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
@@ -1822,6 +1823,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
   set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG})
 
+  set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest")
+  source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src})
+
+  # test library
   file(GLOB onnxruntime_autoep_test_SRC
        "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h"
        "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc")
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
index 792f0ddd0f777..79e6dbbb11c89 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
@@ -214,6 +214,82 @@ public IReadOnlyDictionary<string, NodeMetadata> OverridableInitializerMetadata
             }
         }
 
+        /// <summary>
+        /// Fetches memory info for all inputs in the same order as their names.
+        /// (See InputNames property).
+        /// </summary>
+        /// <returns>A disposable readonly collection of OrtMemoryInfo</returns>
+        public IDisposableReadOnlyCollection<OrtMemoryInfo> GetMemoryInfosForInputs()
+        {
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out UIntPtr numInputs));
+
+            if (numInputs == UIntPtr.Zero)
+            {
+                return new DisposableList<OrtMemoryInfo>();
+            }
+
+            var memoryInfoArray = new IntPtr[(ulong)numInputs];
+
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetMemoryInfoForInputs(_nativeHandle,
+                memoryInfoArray, numInputs));
+
+            return new DisposableList<OrtMemoryInfo>(
+                memoryInfoArray.Select(static ptr => new OrtMemoryInfo(ptr, /* owned= */ false)));
+        }
+
+        /// <summary>
+        /// Fetches memory info for all outputs in the same order as their names.
+        /// (See OutputNames property).
+        /// </summary>
+        /// <returns>A disposable readonly collection of OrtMemoryInfo</returns>
+        public IDisposableReadOnlyCollection<OrtMemoryInfo> GetMemoryInfosForOutputs()
+        {
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle,
+                out UIntPtr numOutputs));
+
+            if (numOutputs == UIntPtr.Zero)
+            {
+                return new DisposableList<OrtMemoryInfo>();
+            }
+
+            var memoryInfoArray = new IntPtr[(ulong)numOutputs];
+
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetMemoryInfoForOutputs(_nativeHandle,
+                memoryInfoArray, numOutputs));
+            return new DisposableList<OrtMemoryInfo>(
+                memoryInfoArray.Select(static ptr => new OrtMemoryInfo(ptr, /* owned= */ false)));
+        }
+
+        /// <summary>
+        /// Fetches OrtEpDevice instances for all inputs in the same order as their input names.
+        /// For inputs that do not have a device, the corresponding entry in the returned list is null.
+        /// See InputNames property.
+        /// </summary>
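+        /// <example>
+        /// A minimal sketch (hedged; names as added in this patch) of checking which device serves an input:
+        /// <code>
+        /// var devices = session.GetEpDeviceForInputs();
+        /// // entries can be null when an input has no assigned device
+        /// var firstInputEp = devices[0]?.EpName;
+        /// </code>
+        /// </example>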
+        /// <returns>IReadOnlyList of OrtEpDevice entries</returns>
+        public IReadOnlyList<OrtEpDevice> GetEpDeviceForInputs()
+        {
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle,
+                out UIntPtr numInputs));
+
+            if (numInputs == UIntPtr.Zero)
+            {
+                // OrtSessionGetEpDeviceForInputs expects numInputs > 0, otherwise it is an invalid arg.
+                return [];
+            }
+
+            var epDevicesForInputs = new IntPtr[(ulong)numInputs];
+
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetEpDeviceForInputs(_nativeHandle,
+                epDevicesForInputs, numInputs));
+
+            // Some entries in epDevicesForInputs can be IntPtr.Zero, indicating the input does not
+            // have a device; return null for those entries.
+            return epDevicesForInputs
+                .Select(static ptr => ptr == IntPtr.Zero ? null : new OrtEpDevice(ptr))
+                .ToList()
+                .AsReadOnly();
+        }
+
         /// <summary>
         /// Runs the loaded model for the given inputs, and fetches all the outputs.
         /// </summary>
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 53880308da261..b97adfbd564d5 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -500,7 +500,7 @@ static NativeMethods()
             OrtCreateEnvWithGlobalThreadPools = (DOrtCreateEnvWithGlobalThreadPools)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithGlobalThreadPools, typeof(DOrtCreateEnvWithGlobalThreadPools));
             OrtCreateEnvWithCustomLoggerAndGlobalThreadPools = (DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLoggerAndGlobalThreadPools, typeof(DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools));
             OrtReleaseEnv = (DOrtReleaseEnv)Marshal.GetDelegateForFunctionPointer(api_.ReleaseEnv, typeof(DOrtReleaseEnv));
-
+
             OrtEnableTelemetryEvents = (DOrtEnableTelemetryEvents)Marshal.GetDelegateForFunctionPointer(api_.EnableTelemetryEvents, typeof(DOrtEnableTelemetryEvents));
             OrtDisableTelemetryEvents = (DOrtDisableTelemetryEvents)Marshal.GetDelegateForFunctionPointer(api_.DisableTelemetryEvents, typeof(DOrtDisableTelemetryEvents));
 
@@ -527,6 +527,9 @@ static NativeMethods()
             OrtSessionGetInputTypeInfo = (DOrtSessionGetInputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputTypeInfo, typeof(DOrtSessionGetInputTypeInfo));
             OrtSessionGetOutputTypeInfo = (DOrtSessionGetOutputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputTypeInfo, typeof(DOrtSessionGetOutputTypeInfo));
             OrtSessionGetOverridableInitializerTypeInfo = (DOrtSessionGetOverridableInitializerTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerTypeInfo, typeof(DOrtSessionGetOverridableInitializerTypeInfo));
+            OrtSessionGetMemoryInfoForInputs = (DOrtSessionGetMemoryInfoForInputs)Marshal.GetDelegateForFunctionPointer(api_.SessionGetMemoryInfoForInputs, typeof(DOrtSessionGetMemoryInfoForInputs));
+            OrtSessionGetMemoryInfoForOutputs = (DOrtSessionGetMemoryInfoForOutputs)Marshal.GetDelegateForFunctionPointer(api_.SessionGetMemoryInfoForOutputs, typeof(DOrtSessionGetMemoryInfoForOutputs));
+            OrtSessionGetEpDeviceForInputs = (DOrtSessionGetEpDeviceForInputs)Marshal.GetDelegateForFunctionPointer(api_.SessionGetEpDeviceForInputs, typeof(DOrtSessionGetEpDeviceForInputs));
             OrtReleaseTypeInfo = (DOrtReleaseTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.ReleaseTypeInfo, typeof(DOrtReleaseTypeInfo));
             OrtReleaseSession = (DOrtReleaseSession)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSession, typeof(DOrtReleaseSession));
OrtSessionGetProfilingStartTimeNs = (DOrtSessionGetProfilingStartTimeNs)Marshal.GetDelegateForFunctionPointer(api_.SessionGetProfilingStartTimeNs, typeof(DOrtSessionGetProfilingStartTimeNs)); @@ -588,6 +591,9 @@ static NativeMethods() OrtMemoryInfoGetMemType = (DOrtMemoryInfoGetMemType)Marshal.GetDelegateForFunctionPointer(api_.MemoryInfoGetMemType, typeof(DOrtMemoryInfoGetMemType)); OrtMemoryInfoGetType = (DOrtMemoryInfoGetType)Marshal.GetDelegateForFunctionPointer(api_.MemoryInfoGetType, typeof(DOrtMemoryInfoGetType)); OrtGetAllocatorWithDefaultOptions = (DOrtGetAllocatorWithDefaultOptions)Marshal.GetDelegateForFunctionPointer(api_.GetAllocatorWithDefaultOptions, typeof(DOrtGetAllocatorWithDefaultOptions)); + OrtCreateMemoryInfoV2 = (DOrtCreateMemoryInfoV2)Marshal.GetDelegateForFunctionPointer(api_.CreateMemoryInfo_V2, typeof(DOrtCreateMemoryInfoV2)); + OrtMemoryInfoGetDeviceMemType = (DOrtMemoryInfoGetDeviceMemType)Marshal.GetDelegateForFunctionPointer(api_.MemoryInfoGetDeviceMemType, typeof(DOrtMemoryInfoGetDeviceMemType)); + OrtMemoryInfoGetVendorId = (DOrtMemoryInfoGetVendorId)Marshal.GetDelegateForFunctionPointer(api_.MemoryInfoGetVendorId, typeof(DOrtMemoryInfoGetVendorId)); OrtCreateAllocator = (DOrtCreateAllocator)Marshal.GetDelegateForFunctionPointer(api_.CreateAllocator, typeof(DOrtCreateAllocator)); OrtReleaseAllocator = (DOrtReleaseAllocator)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAllocator, typeof(DOrtReleaseAllocator)); OrtAllocatorAlloc = (DOrtAllocatorAlloc)Marshal.GetDelegateForFunctionPointer(api_.AllocatorAlloc, typeof(DOrtAllocatorAlloc)); @@ -610,6 +616,7 @@ static NativeMethods() OrtTensorAt = (DOrtTensorAt)Marshal.GetDelegateForFunctionPointer(api_.TensorAt, typeof(DOrtTensorAt)); OrtCreateAndRegisterAllocator = (DOrtCreateAndRegisterAllocator)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocator, typeof(DOrtCreateAndRegisterAllocator)); + OrtUnregisterAllocator = (DOrtUnregisterAllocator)Marshal.GetDelegateForFunctionPointer(api_.UnregisterAllocator, typeof(DOrtUnregisterAllocator)); OrtSetLanguageProjection = (DOrtSetLanguageProjection)Marshal.GetDelegateForFunctionPointer(api_.SetLanguageProjection, typeof(DOrtSetLanguageProjection)); OrtHasValue = (DOrtHasValue)Marshal.GetDelegateForFunctionPointer(api_.HasValue, typeof(DOrtHasValue)); @@ -696,11 +703,11 @@ static NativeMethods() OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions)); OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2)); OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync)); - CreateLoraAdapter = (DCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter, - typeof(DCreateLoraAdapter)); - CreateLoraAdapterFromArray = (DCreateLoraAdapterFromArray)Marshal.GetDelegateForFunctionPointer (api_.CreateLoraAdapterFromArray, typeof(DCreateLoraAdapterFromArray)); - ReleaseLoraAdapter = (DReleaseLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.ReleaseLoraAdapter, - typeof(DReleaseLoraAdapter)); + OrtCreateLoraAdapter = (DOrtCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter, + typeof(DOrtCreateLoraAdapter)); + OrtCreateLoraAdapterFromArray = 
(DOrtCreateLoraAdapterFromArray)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapterFromArray, typeof(DOrtCreateLoraAdapterFromArray)); + OrtReleaseLoraAdapter = (DOrtReleaseLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.ReleaseLoraAdapter, + typeof(DOrtReleaseLoraAdapter)); OrtRunOptionsAddActiveLoraAdapter = (DOrtRunOptionsAddActiveLoraAdapter)Marshal.GetDelegateForFunctionPointer( api_.RunOptionsAddActiveLoraAdapter, typeof(DOrtRunOptionsAddActiveLoraAdapter)); @@ -759,12 +766,15 @@ static NativeMethods() OrtEpDevice_Device = (DOrtEpDevice_Device)Marshal.GetDelegateForFunctionPointer( api_.EpDevice_Device, typeof(DOrtEpDevice_Device)); - OrtRegisterExecutionProviderLibrary = + OrtEpDevice_MemoryInfo = (DOrtEpDevice_MemoryInfo)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_MemoryInfo, typeof(DOrtEpDevice_MemoryInfo)); + + OrtRegisterExecutionProviderLibrary = (DOrtRegisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( api_.RegisterExecutionProviderLibrary, typeof(DOrtRegisterExecutionProviderLibrary)); - OrtUnregisterExecutionProviderLibrary = + OrtUnregisterExecutionProviderLibrary = (DOrtUnregisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( api_.UnregisterExecutionProviderLibrary, typeof(DOrtUnregisterExecutionProviderLibrary)); @@ -773,12 +783,12 @@ static NativeMethods() api_.GetEpDevices, typeof(DOrtGetEpDevices)); - OrtSessionOptionsAppendExecutionProvider_V2 = + OrtSessionOptionsAppendExecutionProvider_V2 = (DOrtSessionOptionsAppendExecutionProvider_V2)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsAppendExecutionProvider_V2, typeof(DOrtSessionOptionsAppendExecutionProvider_V2)); - OrtSessionOptionsSetEpSelectionPolicy = + OrtSessionOptionsSetEpSelectionPolicy = (DSessionOptionsSetEpSelectionPolicy)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsSetEpSelectionPolicy, typeof(DSessionOptionsSetEpSelectionPolicy)); @@ -817,6 +827,40 @@ static NativeMethods() api_.CreateExternalInitializerInfo, typeof(DOrtCreateExternalInitializerInfo)); + OrtCreateSharedAllocator = + (DOrtCreateSharedAllocator)Marshal.GetDelegateForFunctionPointer( + api_.CreateSharedAllocator, + typeof(DOrtCreateSharedAllocator)); + + OrtGetSharedAllocator = + (DOrtGetSharedAllocator)Marshal.GetDelegateForFunctionPointer( + api_.GetSharedAllocator, + typeof(DOrtGetSharedAllocator)); + + OrtReleaseSharedAllocator = + (DOrtReleaseSharedAllocator)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseSharedAllocator, + typeof(DOrtReleaseSharedAllocator)); + + OrtCreateSyncStreamForEpDevice = + (DOrtCreateSyncStreamForEpDevice)Marshal.GetDelegateForFunctionPointer( + api_.CreateSyncStreamForEpDevice, + typeof(DOrtCreateSyncStreamForEpDevice)); + + OrtSyncStream_GetHandle = + (DOrtSyncStream_GetHandle)Marshal.GetDelegateForFunctionPointer( + api_.SyncStream_GetHandle, + typeof(DOrtSyncStream_GetHandle)); + + OrtReleaseSyncStream = + (DOrtReleaseSyncStream)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseSyncStream, + typeof(DOrtReleaseSyncStream)); + + OrtCopyTensors = + (DOrtCopyTensors)Marshal.GetDelegateForFunctionPointer( + api_.CopyTensors, + typeof(DOrtCopyTensors)); } internal class NativeLib @@ -839,7 +883,7 @@ internal class NativeLib public static extern ref OrtApiBase OrtGetApiBase(); #endif -#region Runtime / Environment API + #region Runtime / Environment API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtCreateEnv( @@ -896,9 +940,59 @@ internal class NativeLib public delegate 
IntPtr /* OrtStatus* */ DOrtUpdateEnvWithCustomLogLevel(IntPtr /*(OrtEnv*)*/ env, OrtLoggingLevel custom_log_level);
        public static DOrtUpdateEnvWithCustomLogLevel OrtUpdateEnvWithCustomLogLevel;
 
-#endregion Runtime / Environment API
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /*(OrtStatus*)*/ DCreateAndRegisterAllocatorV2(
+            IntPtr /* OrtEnv* */ environment,
+            IntPtr /* const char* */ providerType,
+            IntPtr /* const OrtMemoryInfo* */ memInfo,
+            IntPtr /* const OrtArenaCfg* */ arenaCfg,
+            IntPtr[] /* const char* const* */ providerOptionsKeys,
+            IntPtr[] /* const char* const* */ providerOptionsValues,
+            UIntPtr /* size_t */ numKeys);
+        public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2;
+
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /* OrtStatus* */ DOrtCreateSharedAllocator(
+            IntPtr /* OrtEnv* */ ortEnv,
+            IntPtr /* OrtEpDevice* */ epDevice,
+            OrtDeviceMemoryType deviceMemoryType,
+            OrtAllocatorType allocatorType,
+            IntPtr /* const OrtKeyValuePairs* */ allocatorOptions,
+            out IntPtr /* OrtAllocator** */ allocator);
+
+        public static DOrtCreateSharedAllocator OrtCreateSharedAllocator;
+
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /* OrtStatus* */ DOrtGetSharedAllocator(
+            IntPtr /*(OrtEnv*)*/ env,
+            IntPtr /*(const OrtMemoryInfo*)*/ memInfo,
+            out IntPtr /* OrtAllocator** */ allocator);
+
+        public static DOrtGetSharedAllocator OrtGetSharedAllocator;
+
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /* OrtStatus* */ DOrtReleaseSharedAllocator(
+            IntPtr /*(OrtEnv*)*/ env,
+            IntPtr /* const OrtEpDevice* */ epDevice,
+            OrtDeviceMemoryType deviceMemoryType);
+
+        public static DOrtReleaseSharedAllocator OrtReleaseSharedAllocator;
+
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /* OrtStatus* */ DOrtCopyTensors(
+            IntPtr /* const OrtEnv* */ env,
+            IntPtr[] /* const OrtValue* const* */ srcTensors,
+            IntPtr[] /* OrtValue* const* */ dstTensors,
+            IntPtr /* OrtSyncStream* */ stream,
+            UIntPtr /* size_t */ numTensors
+        );
+
+        public static DOrtCopyTensors OrtCopyTensors;
 
-#region Provider Options API
+
+        #endregion Runtime / Environment API
+
+        #region Provider Options API
 
        /// <summary>
        /// Creates native OrtTensorRTProviderOptions instance
@@ -1032,9 +1126,9 @@ internal class NativeLib
        public delegate void DOrtReleaseROCMProviderOptions(IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance);
        public static DOrtReleaseROCMProviderOptions OrtReleaseROCMProviderOptions;
 
-#endregion
+        #endregion
 
-#region Status API
+        #region Status API
        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
        public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/ status);
        public static DOrtGetErrorCode OrtGetErrorCode;
@@ -1049,12 +1143,12 @@ internal class NativeLib
        public delegate void DOrtReleaseStatus(IntPtr /*(OrtStatus*)*/ statusPtr);
        public static DOrtReleaseStatus OrtReleaseStatus;
 
-#endregion Status API
+        #endregion Status API
 
-#region InferenceSession API
+        #region InferenceSession API
        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
        public delegate IntPtr /* OrtStatus* */ DOrtCreateStatus(
-            uint /* OrtErrorCode */ code,
+            uint /* OrtErrorCode */ code,
            byte[] /* const char* */ msg);
        public static DOrtCreateStatus OrtCreateStatus;
 
@@ -1216,6 +1310,30 @@ IntPtr[] outputValues /* An array of output value pointers.
Array must be alloca public static DOrtSessionGetOverridableInitializerTypeInfo OrtSessionGetOverridableInitializerTypeInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetMemoryInfoForInputs( + IntPtr /*(const OrtSession*)*/ session, + IntPtr[] /* const OrtMemoryInfo** */ inputsMemoryInfos, + UIntPtr /* size_t */ numInputs); + + public static DOrtSessionGetMemoryInfoForInputs OrtSessionGetMemoryInfoForInputs; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetMemoryInfoForOutputs( + IntPtr /*(const OrtSession*)*/ session, + IntPtr[] /* OrtMemoryInfo** */ outputsMemoryInfos, + UIntPtr /* size_t */ numOutputs); + + public static DOrtSessionGetMemoryInfoForOutputs OrtSessionGetMemoryInfoForOutputs; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetEpDeviceForInputs( + IntPtr /*(const OrtSession*)*/ session, + IntPtr[] /* const OrtDevice** */ devices, + UIntPtr /* size_t */ numInputs); + + public static DOrtSessionGetEpDeviceForInputs OrtSessionGetEpDeviceForInputs; + // release the typeinfo using OrtReleaseTypeInfo [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/ session); @@ -1231,17 +1349,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca out UIntPtr /*(ulong* out)*/ startTime); public static DOrtSessionGetProfilingStartTimeNs OrtSessionGetProfilingStartTimeNs; - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(ONNStatus*)*/ DCreateAndRegisterAllocatorV2( - IntPtr /* (OrtEnv*) */ environment, - IntPtr /*(char*)*/ provider_type, - IntPtr /*(OrtMemoryInfo*)*/ mem_info, - IntPtr /*(OrtArenaCfg*)*/ arena_cfg, - IntPtr /*(char**)*/ provider_options_keys, - IntPtr /*(char**)*/ provider_options_values, - UIntPtr /*(size_t)*/ num_keys); - public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2; - [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DOrtRunAsync( IntPtr /*(OrtSession*)*/ session, @@ -1256,9 +1363,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca IntPtr /*(void*)*/ user_data); public static DOrtRunAsync OrtRunAsync; -#endregion InferenceSession API + #endregion InferenceSession API -#region SessionOptions API + #region SessionOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateSessionOptions(out IntPtr /*(OrtSessionOptions**)*/ sessionOptions); @@ -1546,9 +1653,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DSessionOptionsAppendExecutionProvider SessionOptionsAppendExecutionProvider; -#endregion + #endregion -#region LoraAdapter API + #region LoraAdapter API /// /// Memory maps the adapter file, wraps it into the adapter object /// and returns it. @@ -1558,12 +1665,12 @@ IntPtr[] outputValues /* An array of output value pointers. 
Array must be alloca /// New LoraAdapter object /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapter( + public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateLoraAdapter( byte[] adapter_path, // This takes const ORTCHAR_T* use GetPlatformSerializedString IntPtr /* OrtAllocator */ allocator, // optional out IntPtr lora_adapter ); - public static DCreateLoraAdapter CreateLoraAdapter; + public static DOrtCreateLoraAdapter OrtCreateLoraAdapter; /// /// Creates LoraAdapter instance from a byte array that must @@ -1575,22 +1682,22 @@ out IntPtr lora_adapter /// resulting LoraAdapter instance /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapterFromArray( + public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateLoraAdapterFromArray( byte[] bytes, UIntPtr size, IntPtr /* OrtAllocator */ allocator, // optional out IntPtr lora_adapter ); - public static DCreateLoraAdapterFromArray CreateLoraAdapterFromArray; + public static DOrtCreateLoraAdapterFromArray OrtCreateLoraAdapterFromArray; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DReleaseLoraAdapter(IntPtr /* OrtLoraAdapter* */ lora_adapter); - public static DReleaseLoraAdapter ReleaseLoraAdapter; + public delegate void DOrtReleaseLoraAdapter(IntPtr /* OrtLoraAdapter* */ lora_adapter); + public static DOrtReleaseLoraAdapter OrtReleaseLoraAdapter; -#endregion + #endregion -#region RunOptions API + #region RunOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions); @@ -1653,9 +1760,9 @@ out IntPtr lora_adapter byte[] /* const char* */ configValue); public static DOrtAddRunConfigEntry OrtAddRunConfigEntry; -#endregion + #endregion -#region ThreadingOptions API + #region ThreadingOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateThreadingOptions(out IntPtr /* OrtCreateThreadingOptions** */ threadingOptions); @@ -1680,9 +1787,9 @@ out IntPtr lora_adapter [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtThreadingOptionsSetGlobalSpinControl(IntPtr /* OrtThreadingOptions* */ threadingOptions, int allowSpinning); public static DOrtThreadingOptionsSetGlobalSpinControl OrtThreadingOptionsSetGlobalSpinControl; -#endregion + #endregion -#region Allocator / MemoryInfo API + #region Allocator / MemoryInfo API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateMemoryInfo( @@ -1695,6 +1802,20 @@ out IntPtr lora_adapter public static DOrtCreateMemoryInfo OrtCreateMemoryInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateMemoryInfoV2( + byte[] /*(const char*) */ name, + OrtMemoryInfoDeviceType memInfoDeviceType, + UInt32 /* uint32_t */ vendorId, + Int32 /* int32_t */ deviceId, + OrtDeviceMemoryType deviceMemoryType, + UIntPtr /* size_t */ alignment, + OrtAllocatorType allocatorType, + out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transferred to caller + ); + + public static DOrtCreateMemoryInfoV2 OrtCreateMemoryInfoV2; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateCpuMemoryInfo( OrtAllocatorType allocatorType, @@ -1743,6 +1864,18 @@ out IntPtr lora_adapter public static DOrtMemoryInfoGetType 
OrtMemoryInfoGetType;
 
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate OrtDeviceMemoryType DOrtMemoryInfoGetDeviceMemType(
+            IntPtr /*(const OrtMemoryInfo* ptr)*/ memoryInfo);
+
+        public static DOrtMemoryInfoGetDeviceMemType OrtMemoryInfoGetDeviceMemType;
+
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate UInt32 DOrtMemoryInfoGetVendorId(
+            IntPtr /*(const OrtMemoryInfo* ptr)*/ memoryInfo);
+
+        public static DOrtMemoryInfoGetVendorId OrtMemoryInfoGetVendorId;
+
        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
        public delegate IntPtr /*(OrtStatus*)*/ DOrtGetAllocatorWithDefaultOptions(out IntPtr /*(OrtAllocator**)*/ allocator);
@@ -1819,9 +1952,9 @@ out IntPtr lora_adapter
        public static DOrtAllocatorFree OrtAllocatorFree;
 
-#endregion Allocator / MemoryInfo API
+        #endregion Allocator / MemoryInfo API
 
-#region IoBinding API
+        #region IoBinding API
 
        /// <summary>
        /// Create OrtIoBinding instance that is used to bind memory that is allocated
@@ -1985,7 +2118,8 @@ out IntPtr lora_adapter
        /// <summary>
        /// Creates an allocator instance and registers it with the env to enable
        /// sharing between multiple sessions that use the same env instance.
-        /// Lifetime of the created allocator will be valid for the duration of the environment.
+        /// Lifetime of the created allocator will be valid for the duration of the environment
+        /// or until it is explicitly unregistered by UnregisterAllocator.
        /// Returns an error if an allocator with the same OrtMemoryInfo is already registered.
        /// </summary>
        /// <param name="env">Native OrtEnv instance</param>
@@ -1999,6 +2133,20 @@ out IntPtr lora_adapter
        public static DOrtCreateAndRegisterAllocator OrtCreateAndRegisterAllocator;
 
+
+        /// <summary>
+        /// Unregisters an allocator that was previously registered with the env using
+        /// CreateAndRegisterAllocator or CreateAndRegisterAllocatorV2.
+        /// </summary>
+        /// <param name="env">valid env</param>
+        /// <param name="memInfo">meminfo used for registering the allocator</param>
+        ///
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /*(OrtStatus*)*/ DOrtUnregisterAllocator(IntPtr /*(OrtEnv*)*/ env,
+            IntPtr /*(const OrtMemoryInfo*)*/ memInfo);
+
+        public static DOrtUnregisterAllocator OrtUnregisterAllocator;
+
        /// <summary>
        /// Set the language projection for collecting telemetry data when Env is created
        /// </summary>
        ///
@@ -2009,9 +2157,9 @@ out IntPtr lora_adapter
        public static DOrtSetLanguageProjection OrtSetLanguageProjection;
 
-#endregion IoBinding API
+        #endregion IoBinding API
 
-#region ModelMetadata API
+        #region ModelMetadata API
 
        /// <summary>
        /// Gets the ModelMetadata associated with an InferenceSession
        /// </summary>
@@ -2129,9 +2277,9 @@ out IntPtr lora_adapter
        public static DOrtReleaseModelMetadata OrtReleaseModelMetadata;
 
-#endregion ModelMetadata API
+        #endregion ModelMetadata API
 
-#region OrtValue API
+        #region OrtValue API
 
        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
        public delegate IntPtr /*(OrtStatus*)*/ DOrtHasValue(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(int*)*/ hasValue);
@@ -2397,9 +2545,9 @@ out IntPtr lora_adapter
        public static DOrtReleaseValue OrtReleaseValue;
 
-#endregion
+        #endregion
 
-#region Compile API
+        #region Compile API
 
 #if NETSTANDARD2_0
        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
@@ -2473,9 +2621,10 @@ out IntPtr /* OrtExternalInitializerInfo** */ newExternalInfo
        public static DOrtExternalInitializerInfo_GetFilePath OrtExternalInitializerInfo_GetFilePath;
        public static DOrtExternalInitializerInfo_GetFileOffset OrtExternalInitializerInfo_GetFileOffset;
        public static DOrtExternalInitializerInfo_GetByteSize OrtExternalInitializerInfo_GetByteSize;
-#endregion
-#region Auto EP API related
+        #endregion
+
+        #region Auto
EP API related // // OrtKeyValuePairs @@ -2582,12 +2731,36 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public delegate IntPtr /* const OrtHardwareDevice* */ DOrtEpDevice_Device( IntPtr /* const OrtEpDevice* */ ep_device); + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtMemoryInfo* */ DOrtEpDevice_MemoryInfo( + IntPtr /* const OrtEpDevice* */ ep_device, OrtDeviceMemoryType deviceMemoryType); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateSyncStreamForEpDevice( + IntPtr /* const OrtEpDevice* */ epDevice, + IntPtr /* const OrtKeyValuePairs* */ streamOptions, + out IntPtr /* OrtSyncStream** */ stream + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* void* */ DOrtSyncStream_GetHandle( + IntPtr /* OrtSyncStream* */ stream + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtReleaseSyncStream( + IntPtr /* OrtSyncStream* */ stream + ); public static DOrtEpDevice_EpName OrtEpDevice_EpName; public static DOrtEpDevice_EpVendor OrtEpDevice_EpVendor; public static DOrtEpDevice_EpMetadata OrtEpDevice_EpMetadata; public static DOrtEpDevice_EpOptions OrtEpDevice_EpOptions; public static DOrtEpDevice_Device OrtEpDevice_Device; + public static DOrtEpDevice_MemoryInfo OrtEpDevice_MemoryInfo; + public static DOrtCreateSyncStreamForEpDevice OrtCreateSyncStreamForEpDevice; + public static DOrtSyncStream_GetHandle OrtSyncStream_GetHandle; + public static DOrtReleaseSyncStream OrtReleaseSyncStream; // // Auto Selection EP registration and selection customization @@ -2763,7 +2936,7 @@ public delegate IntPtr DOrtEpSelectionDelegate( public static DOrtReleasePrepackedWeightsContainer OrtReleasePrepackedWeightsContainer; -#endregion + #endregion } // class NativeMethods // onnxruntime-extensions helpers to make usage simpler. 
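Taken together, the bindings registered above back the new public InferenceSession queries added earlier in this patch. A minimal usage sketch, assuming an already-built model at a hypothetical path; everything else is as defined in this diff:

using System;
using Microsoft.ML.OnnxRuntime;

// "model.onnx" is a placeholder path.
using var session = new InferenceSession("model.onnx");

// Memory info and (optional) EP device per input, in InputNames order.
using var inputMemInfos = session.GetMemoryInfosForInputs();
var inputDevices = session.GetEpDeviceForInputs();

int i = 0;
foreach (var memInfo in inputMemInfos)
{
    var device = inputDevices[i]; // null when the input has no assigned device
    Console.WriteLine($"{session.InputNames[i]}: memory={memInfo.Name} ep={device?.EpName ?? "n/a"}");
    i++;
}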
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
index 4611428ea12ef..f5dc253195ab1 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
@@ -373,23 +373,21 @@ internal static void Update(Dictionary<string, string> providerOptions,
             IntPtr handle,
             Func<IntPtr, IntPtr[], IntPtr[], UIntPtr, IntPtr> updateFunc)
         {
-            var keyStrings = providerOptions.Keys.ToArray();
-            var valStrings = providerOptions.Values.ToArray();
-
             MarshaledStringArray keys = default;
             MarshaledStringArray values = default;
             try
             {
-                keys = new MarshaledStringArray(keyStrings);
-                values = new MarshaledStringArray(valStrings);
+                keys = new MarshaledStringArray(providerOptions.Keys);
+                values = new MarshaledStringArray(providerOptions.Values);
 
-                var nativeKeys = new IntPtr[keyStrings.Length];
+                var nativeKeys = new IntPtr[providerOptions.Count];
                 keys.Fill(nativeKeys);
-                var nativeVals = new IntPtr[valStrings.Length];
+                var nativeVals = new IntPtr[providerOptions.Count];
                 values.Fill(nativeVals);
 
-                NativeApiStatus.VerifySuccess(updateFunc(handle, nativeKeys, nativeVals, (UIntPtr)providerOptions.Count));
+                NativeApiStatus.VerifySuccess(updateFunc(handle, nativeKeys, nativeVals,
+                    (UIntPtr)providerOptions.Count));
             }
             finally
             {
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs
index 3f918fc2ad6c8..c189cc1856252 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs
@@ -3,6 +3,7 @@
 
 using Microsoft.ML.OnnxRuntime.Tensors;
 using System;
+using System.Reflection;
 using System.Runtime.InteropServices;
 using System.Text;
 
@@ -28,6 +29,28 @@ public enum OrtMemType
         Default = 0, // the default allocator for execution provider
     }
 
+    /// <summary>
+    /// See documentation for OrtDeviceMemoryType in C API.
+    /// This matches OrtDevice::MemoryType values.
+    /// </summary>
+    public enum OrtDeviceMemoryType
+    {
+        DEFAULT = 0,          // Device memory
+        HOST_ACCESSIBLE = 5,  // Shared/pinned memory for transferring between CPU and the device
+    }
+
+    /// <summary>
+    /// See documentation for OrtMemoryInfoDeviceType in C API.
+    /// This mimics OrtDevice type constants so they can be returned in the API.
+    /// </summary>
+    public enum OrtMemoryInfoDeviceType
+    {
+        CPU = 0,
+        GPU = 1,
+        FPGA = 2,
+        NPU = 3,
+    }
+
     /// <summary>
     /// This class encapsulates arena configuration information that will be used to define the behavior
     /// of an arena based allocator
     /// </summary>
@@ -103,7 +126,8 @@ public class OrtMemoryInfo : SafeHandle
         private static OrtMemoryInfo CreateCpuMemoryInfo()
         {
             // Returns OrtMemoryInfo instance that needs to be disposed
-            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuMemoryInfo(OrtAllocatorType.DeviceAllocator, OrtMemType.Cpu, out IntPtr memoryInfo));
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuMemoryInfo(OrtAllocatorType.DeviceAllocator,
+                OrtMemType.Cpu, out IntPtr memoryInfo));
             return new OrtMemoryInfo(memoryInfo, true);
         }
 
@@ -203,6 +227,26 @@ public OrtMemoryInfo(byte[] utf8AllocatorName, OrtAllocatorType allocatorType, i
         public OrtMemoryInfo(string allocatorName, OrtAllocatorType allocatorType, int deviceId, OrtMemType memoryType)
             : this(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(allocatorName), allocatorType, deviceId, memoryType)
         {
+
         }
+
+        /// <summary>
+        /// Creates an instance of OrtMemoryInfo using OrtCreateMemoryInfoV2
+        /// </summary>
+        /// <param name="allocatorName">In this overload this is an arbitrary name</param>
+        /// <param name="deviceType">Device Type</param>
+        /// <param name="vendorId">Vendor Id</param>
+        /// <param name="deviceId">Device Id</param>
+        /// <param name="deviceMemoryType">Device Memory Type</param>
+        /// <param name="alignment">Required alignment, or 0 for the default</param>
+        /// <param name="allocatorType">Allocator Type</param>
+        public OrtMemoryInfo(string allocatorName, OrtMemoryInfoDeviceType deviceType, uint vendorId,
+            int deviceId, OrtDeviceMemoryType deviceMemoryType, ulong alignment, OrtAllocatorType allocatorType)
+            : base(IntPtr.Zero, true)
+        {
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMemoryInfoV2(
+                NativeOnnxValueHelper.StringToZeroTerminatedUtf8(allocatorName),
+                deviceType, vendorId, deviceId, deviceMemoryType, (UIntPtr)alignment, allocatorType, out handle));
+        }
 
         /// <summary>
@@ -252,6 +296,24 @@ public OrtAllocatorType GetAllocatorType()
             return allocatorType;
         }
 
+        /// <summary>
+        /// Return the device memory type associated with this memory info
+        /// </summary>
+        /// <returns>OrtDeviceMemoryType for the device</returns>
+        public OrtDeviceMemoryType GetDeviceMemoryType()
+        {
+            return NativeMethods.OrtMemoryInfoGetDeviceMemType(handle);
+        }
+
+        /// <summary>
+        /// Fetches the vendor ID
+        /// </summary>
+        /// <returns>vendor ID as uint32_t</returns>
+        public uint GetVendorId()
+        {
+            return NativeMethods.OrtMemoryInfoGetVendorId(handle);
+        }
+
         /// <summary>
         /// Overrides System.Object.Equals(object)
         /// </summary>
@@ -493,12 +555,6 @@ internal IntPtr Pointer
             }
         }
 
-        /// <summary>
-        /// Overrides SafeHandle.IsInvalid
-        /// </summary>
-        /// <value>returns true if handle is equal to Zero</value>
-        public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
-
         /// <summary>
         /// Internal constructor wraps existing native allocators
         /// </summary>
@@ -560,6 +616,14 @@ internal void FreeMemory(IntPtr allocation)
         }
 
 #region SafeHandle
+
+        /// <summary>
+        /// Overrides SafeHandle.IsInvalid
+        /// </summary>
+        /// <value>returns true if handle is equal to Zero</value>
+        public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
+
+
         /// <summary>
         /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
         /// the native instance of OrtAllocator
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
index 052d5899b52c0..6fcff438c5cf3 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
@@ -329,14 +329,115 @@ public void DisableTelemetryEvents()
         }
 
         /// <summary>
-        /// Create and register an allocator to the OrtEnv instance
-        /// so as to enable sharing across all sessions using the OrtEnv instance
+        /// Create and register an allocator to the OrtEnv instance,
+        /// enabling sharing between multiple sessions that use the same env instance.
+        /// The lifetime of the created allocator is valid for the duration of the environment.
         /// </summary>
         /// <param name="memInfo">OrtMemoryInfo instance to be used for allocator creation</param>
         /// <param name="arenaCfg">OrtArenaCfg instance that will be used to define the behavior of the arena based allocator</param>
         public void CreateAndRegisterAllocator(OrtMemoryInfo memInfo, OrtArenaCfg arenaCfg)
         {
-            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateAndRegisterAllocator(Handle, memInfo.Pointer, arenaCfg.Pointer));
+            NativeApiStatus.VerifySuccess(
+                NativeMethods.OrtCreateAndRegisterAllocator(Handle, memInfo.Pointer, arenaCfg.Pointer));
+        }
+
+        /// <summary>
+        /// Create and register an allocator to the OrtEnv instance.
+        /// This overload enhances CreateAndRegisterAllocator in that it can create an allocator
+        /// for a specific execution provider type, not just a CPU allocator.
+        /// Use UnregisterAllocator to unregister it.
+        /// </summary>
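+        /// <example>
+        /// A minimal sketch, mirroring this repo's tests (provider options may be empty):
+        /// <code>
+        /// using var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU,
+        ///     OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default);
+        /// using var arenaCfg = new OrtArenaCfg(0, -1, -1, -1);
+        /// var env = OrtEnv.Instance();
+        /// env.CreateAndRegisterAllocator("CPUExecutionProvider", memInfo, arenaCfg,
+        ///     new Dictionary&lt;string, string&gt;());
+        /// // ... sessions created with config "session.use_env_allocators" = "1" share it ...
+        /// env.UnregisterAllocator(memInfo);
+        /// </code>
+        /// </example>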
+        /// <param name="providerType">Execution provider type the allocator is intended for, e.g. "CPUExecutionProvider"</param>
+        /// <param name="memInfo">OrtMemoryInfo instance to be used for allocator creation</param>
+        /// <param name="arenaCfg">OrtArenaCfg instance that defines the behavior of an arena based allocator</param>
+        /// <param name="provider_options">provider specific options</param>
+        public void CreateAndRegisterAllocator(string providerType, OrtMemoryInfo memInfo, OrtArenaCfg arenaCfg,
+            IReadOnlyDictionary<string, string> provider_options)
+        {
+            MarshaledStringArray marshalledKeys = default;
+            MarshaledStringArray marshalledValues = default;
+            var keysPtrs = new IntPtr[provider_options.Count];
+            var valuesPtrs = new IntPtr[provider_options.Count];
+
+            try
+            {
+                marshalledKeys = new MarshaledStringArray(provider_options.Keys);
+                marshalledValues = new MarshaledStringArray(provider_options.Values);
+                marshalledKeys.Fill(keysPtrs);
+                marshalledValues.Fill(valuesPtrs);
+                using var marshalledProviderType = new MarshaledString(providerType);
+
+                NativeApiStatus.VerifySuccess(
+                    NativeMethods.OrtCreateAndRegisterAllocatorV2(Handle, marshalledProviderType.Value,
+                        memInfo.Pointer, arenaCfg.Pointer,
+                        keysPtrs, valuesPtrs,
+                        (UIntPtr)provider_options.Count));
+            }
+            finally
+            {
+                marshalledValues.Dispose();
+                marshalledKeys.Dispose();
+            }
+        }
+
+        /// <summary>
+        /// Unregister a custom allocator previously registered with the OrtEnv instance
+        /// using CreateAndRegisterAllocator.
+        /// </summary>
+        /// <param name="memInfo">The memory info instance should correspond to the one used for registration</param>
+        public void UnregisterAllocator(OrtMemoryInfo memInfo)
+        {
+            NativeApiStatus.VerifySuccess(
+                NativeMethods.OrtUnregisterAllocator(Handle, memInfo.Pointer));
+        }
+
+        /// <summary>
+        /// Creates a shared allocator owned by the OrtEnv instance.
+        /// </summary>
+        /// <param name="epDevice">OrtEpDevice the allocator is created for</param>
+        /// <param name="deviceMemoryType">device memory type</param>
+        /// <param name="ortAllocatorType">allocator type</param>
+        /// <param name="allocatorOptions">allocator specific options</param>
+        /// <returns>OrtAllocator instance</returns>
+        public OrtAllocator CreateSharedAllocator(OrtEpDevice epDevice, OrtDeviceMemoryType deviceMemoryType,
+            OrtAllocatorType ortAllocatorType, IReadOnlyDictionary<string, string> allocatorOptions)
+        {
+            using var keyValueOptions = new OrtKeyValuePairs(allocatorOptions);
+            NativeApiStatus.VerifySuccess(
+                NativeMethods.OrtCreateSharedAllocator(Handle, epDevice.Handle, deviceMemoryType,
+                    ortAllocatorType, keyValueOptions.Handle, out IntPtr allocatorHandle));
+            return new OrtAllocator(allocatorHandle, /* owned= */ false);
+        }
+
+        /// <summary>
+        /// Returns a shared allocator owned by the OrtEnv instance if such exists
+        /// (was previously created). If no such allocator exists, the API returns null.
+        /// </summary>
+        /// <param name="memoryInfo">memory info the shared allocator was created for</param>
+        /// <returns>OrtAllocator instance or null if the requested allocator does not exist</returns>
+        public OrtAllocator GetSharedAllocator(OrtMemoryInfo memoryInfo)
+        {
+            NativeApiStatus.VerifySuccess(
+                NativeMethods.OrtGetSharedAllocator(Handle, memoryInfo.Pointer, out IntPtr allocatorHandle));
+            if (allocatorHandle == IntPtr.Zero)
+            {
+                return null;
+            }
+            return new OrtAllocator(allocatorHandle, /* owned= */ false);
+        }
+
+        /// <summary>
+        /// Release the shared allocator from the OrtEnv for the given OrtEpDevice and memory type.
+        /// If no shared allocator exists, this is a no-op.
+        /// </summary>
+        /// <param name="epDevice">OrtEpDevice the allocator was created for</param>
+        /// <param name="deviceMemoryType">device memory type</param>
+        public void ReleaseSharedAllocator(OrtEpDevice epDevice, OrtDeviceMemoryType deviceMemoryType)
+        {
+            NativeApiStatus.VerifySuccess(
+                NativeMethods.OrtReleaseSharedAllocator(Handle, epDevice.Handle, deviceMemoryType));
         }
 
         /// <summary>
@@ -477,7 +578,37 @@ public IReadOnlyList<OrtEpDevice> GetEpDevices()
             }
 
             return epDevices.AsReadOnly();
-        }
+        }
+
+        /// <summary>
+        /// Copies data from source OrtValue tensors to destination OrtValue tensors.
+        /// The tensors may reside on different devices if such transfers are supported
+        /// by the registered execution providers.
+        /// </summary>
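+        /// <example>
+        /// A minimal sketch based on this repo's tests; data, dims, deviceAllocator and env
+        /// are placeholders for values created elsewhere:
+        /// <code>
+        /// using var src = OrtValue.CreateTensorValueFromMemory(data, dims);
+        /// using var dst = OrtValue.CreateAllocatedTensorValue(deviceAllocator,
+        ///     TensorElementType.Float, dims);
+        /// env.CopyTensors(new[] { src }, new[] { dst }, stream: null);
+        /// </code>
+        /// </example>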
+        /// <param name="srcValues">Source OrtValues</param>
+        /// <param name="dstValues">Pre-allocated destination OrtValues</param>
+        /// <param name="stream">optional stream, or null</param>
+        /// <exception cref="ArgumentNullException">thrown when a source or destination value is null</exception>
+        public void CopyTensors(IReadOnlyList<OrtValue> srcValues, IReadOnlyList<OrtValue> dstValues,
+            OrtSyncStream stream)
+        {
+            IntPtr streamHandle = stream != null ? stream.Handle : IntPtr.Zero;
+            IntPtr[] srcPtrs = new IntPtr[srcValues.Count];
+            IntPtr[] dstPtrs = new IntPtr[dstValues.Count];
+
+            for (int i = 0; i < srcPtrs.Length; i++)
+            {
+                if (srcValues[i] == null)
+                    throw new ArgumentNullException($"srcValues[{i}]");
+                if (dstValues[i] == null)
+                    throw new ArgumentNullException($"dstValues[{i}]");
+                srcPtrs[i] = srcValues[i].Handle;
+                dstPtrs[i] = dstValues[i].Handle;
+            }
+
+            NativeApiStatus.VerifySuccess(
+                NativeMethods.OrtCopyTensors(handle, srcPtrs, dstPtrs, streamHandle, (UIntPtr)srcPtrs.Length));
+        }
 
 #endregion
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs
index 0318e08519128..9e59754374464 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs
@@ -2,10 +2,51 @@
 // Licensed under the MIT License.
 
 using System;
+using System.Collections.Generic;
 using System.Runtime.InteropServices;
 
 namespace Microsoft.ML.OnnxRuntime
 {
+    /// <summary>
+    /// Represents a synchronization primitive for stream operations.
+    /// </summary>
+    public class OrtSyncStream : SafeHandle
+    {
+        internal OrtSyncStream(IntPtr streamHandle)
+            : base(IntPtr.Zero, true) // Provide required arguments to SafeHandle constructor
+        {
+            handle = streamHandle;
+        }
+
+        /// <summary>
+        /// Fetch sync stream handle for possible use
+        /// in session options.
+        /// </summary>
+        /// <returns>Opaque stream handle</returns>
+        public IntPtr GetHandle()
+        {
+            return NativeMethods.OrtSyncStream_GetHandle(handle);
+        }
+
+        internal IntPtr Handle => handle;
+
+        /// <summary>
+        /// Implements SafeHandle interface
+        /// </summary>
+        public override bool IsInvalid => handle == IntPtr.Zero;
+
+        /// <summary>
+        /// Implements SafeHandle interface to release the native handle
+        /// </summary>
+        /// <returns>always true</returns>
+        protected override bool ReleaseHandle()
+        {
+            NativeMethods.OrtReleaseSyncStream(handle);
+            handle = IntPtr.Zero;
+            return true;
+        }
+    }
+
     /// <summary>
     /// Represents the combination of an execution provider and a hardware device
     /// that the execution provider can utilize.
@@ -81,6 +122,46 @@ public OrtHardwareDevice HardwareDevice
             }
         }
 
+        /// <summary>
+        /// The OrtMemoryInfo instance describing the memory characteristics of the device.
+        /// </summary>
+        /// <param name="deviceMemoryType">memory type requested</param>
+        /// <returns>OrtMemoryInfo instance</returns>
+        public OrtMemoryInfo GetMemoryInfo(OrtDeviceMemoryType deviceMemoryType)
+        {
+            IntPtr memoryInfoPtr = NativeMethods.OrtEpDevice_MemoryInfo(_handle, deviceMemoryType);
+            return new OrtMemoryInfo(memoryInfoPtr, /* owned= */ false);
+        }
+
+        /// <summary>
+        /// Creates a synchronization stream for operations on this device.
+        /// Can be used to implement async operations on the device such as
+        /// CopyTensors.
+        /// </summary>
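+        /// <example>
+        /// A minimal sketch; options may be null, and the returned handle is opaque:
+        /// <code>
+        /// using var stream = epDevice.CreateSyncStream(null);
+        /// var nativeHandle = stream.GetHandle(); // may be IntPtr.Zero for some EPs
+        /// </code>
+        /// </example>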
+        /// <param name="streamOptions">stream options, can be null</param>
+        /// <returns>OrtSyncStream instance</returns>
+        public OrtSyncStream CreateSyncStream(IReadOnlyDictionary<string, string> streamOptions)
+        {
+            OrtKeyValuePairs options = null;
+            IntPtr optionsHandle = IntPtr.Zero;
+            try
+            {
+                if (streamOptions != null)
+                {
+                    options = new OrtKeyValuePairs(streamOptions);
+                    optionsHandle = options.Handle;
+                }
+
+                NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSyncStreamForEpDevice(_handle,
+                    optionsHandle, out IntPtr syncStream));
+                return new OrtSyncStream(syncStream);
+            }
+            finally
+            {
+                options?.Dispose();
+            }
+        }
+
         private readonly IntPtr _handle;
     }
 }
\ No newline at end of file
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs
index 6a8d1037d9017..50fd1965231e1 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs
@@ -169,6 +169,11 @@ private Dictionary<string, string> GetLatest()
             return dict;
         }
 
+        /// <summary>
+        /// Native handle to the OrtKeyValuePairs instance.
+        /// </summary>
+        internal IntPtr Handle => handle;
+
         /// <summary>
         /// Indicates whether the native handle is invalid.
         /// </summary>
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs
index e2249b4c47fec..f1c03faccf16f 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs
@@ -23,7 +23,7 @@ public static OrtLoraAdapter Create(string adapterPath, OrtAllocator ortAllocato
         {
             var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(adapterPath);
             var allocatorHandle = (ortAllocator != null) ? ortAllocator.Pointer : IntPtr.Zero;
-            NativeApiStatus.VerifySuccess(NativeMethods.CreateLoraAdapter(platformPath, allocatorHandle,
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateLoraAdapter(platformPath, allocatorHandle,
                 out IntPtr adapterHandle));
             return new OrtLoraAdapter(adapterHandle);
         }
@@ -38,7 +38,7 @@ public static OrtLoraAdapter Create(string adapterPath, OrtAllocator ortAllocato
         public static OrtLoraAdapter Create(byte[] bytes, OrtAllocator ortAllocator)
         {
             var allocatorHandle = (ortAllocator != null) ?
ortAllocator.Pointer : IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.CreateLoraAdapterFromArray(bytes, + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateLoraAdapterFromArray(bytes, new UIntPtr((uint)bytes.Length), allocatorHandle, out IntPtr adapterHandle)); return new OrtLoraAdapter(adapterHandle); } @@ -71,7 +71,7 @@ internal IntPtr Handle /// always returns true protected override bool ReleaseHandle() { - NativeMethods.ReleaseLoraAdapter(handle); + NativeMethods.OrtReleaseLoraAdapter(handle); handle = IntPtr.Zero; return true; } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 0a39d965979ca..73613541f8362 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -4,13 +4,10 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Collections.Generic; +using System.IO; using System.Linq; -using System.Linq.Expressions; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using System.Text.RegularExpressions; -using System.Threading; using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; @@ -837,7 +834,7 @@ private async Task TestMultiThreads() Assert.Equal(res, expectedOut, (IEqualityComparer)new FloatComparer()); } })); - }; + } await Task.WhenAll(tasks); session.Dispose(); } @@ -1694,37 +1691,52 @@ private void TestInferenceSessionWithByteArray() void TestCPUAllocatorInternal(InferenceSession session) + { int device_id = 0; - using (var info_cpu = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default)) - { - Assert.Equal("Cpu", info_cpu.Name); - Assert.Equal(device_id, info_cpu.Id); - Assert.Equal(OrtAllocatorType.ArenaAllocator, info_cpu.GetAllocatorType()); - Assert.Equal(OrtMemType.Default, info_cpu.GetMemoryType()); - - using (var allocator = new OrtAllocator(session, info_cpu)) - { - var alloc_info = allocator.Info; - // Allocator type returned may be different on x86 so we don't compare. - Assert.Equal(info_cpu.Name, alloc_info.Name); - Assert.Equal(info_cpu.GetMemoryType(), alloc_info.GetMemoryType()); - Assert.Equal(info_cpu.Id, alloc_info.Id); - - uint size = 1024; - OrtMemoryAllocation chunk = allocator.Allocate(size); - Assert.Equal(chunk.Size, size); - var chunk_info = chunk.Info; - // Allocator type returned may be different on x86 so we don't compare. - Assert.Equal(chunk_info.Name, alloc_info.Name); - Assert.Equal(chunk_info.GetMemoryType(), alloc_info.GetMemoryType()); - Assert.Equal(chunk_info.Id, alloc_info.Id); - chunk.Dispose(); - alloc_info.Dispose(); - } - } + using var info_cpu = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, + OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default); + Assert.Equal("Cpu", info_cpu.Name); + Assert.Equal(device_id, info_cpu.Id); + Assert.Equal(OrtAllocatorType.ArenaAllocator, info_cpu.GetAllocatorType()); + Assert.Equal(OrtMemType.Default, info_cpu.GetMemoryType()); + var deviceMemoryType = info_cpu.GetDeviceMemoryType(); + Assert.Equal(OrtDeviceMemoryType.DEFAULT, deviceMemoryType); + Assert.Equal(0U, info_cpu.GetVendorId()); + + using var allocator = new OrtAllocator(session, info_cpu); + using var alloc_info = allocator.Info; + // Allocator type returned may be different on x86 so we don't compare. 
+ Assert.Equal(info_cpu.Name, alloc_info.Name); + Assert.Equal(info_cpu.GetMemoryType(), alloc_info.GetMemoryType()); + Assert.Equal(info_cpu.Id, alloc_info.Id); + + uint size = 1024; + using OrtMemoryAllocation chunk = allocator.Allocate(size); + Assert.Equal(chunk.Size, size); + var chunk_info = chunk.Info; + // Allocator type returned may be different on x86 so we don't compare. + Assert.Equal(chunk_info.Name, alloc_info.Name); + Assert.Equal(chunk_info.GetMemoryType(), alloc_info.GetMemoryType()); + Assert.Equal(chunk_info.Id, alloc_info.Id); + } + + [Fact(DisplayName = "TestMemoryInfoCreateV2")] + void TestMemoryInfoCreateV2() + { + const int device_id = 0; + const uint vendor_id = 1234U; + using var info_cpu = new OrtMemoryInfo("Test_CPU", OrtMemoryInfoDeviceType.CPU, vendor_id, device_id, + OrtDeviceMemoryType.DEFAULT, 0, OrtAllocatorType.DeviceAllocator); + Assert.Equal("Test_CPU", info_cpu.Name); + Assert.Equal(device_id, info_cpu.Id); + Assert.Equal(OrtAllocatorType.DeviceAllocator, info_cpu.GetAllocatorType()); + Assert.Equal(OrtMemType.Default, info_cpu.GetMemoryType()); + Assert.Equal(OrtDeviceMemoryType.DEFAULT, info_cpu.GetDeviceMemoryType()); + Assert.Equal(vendor_id, info_cpu.GetVendorId()); } + #if USE_CUDA void TestCUDAAllocatorInternal(InferenceSession session) { @@ -1896,81 +1908,6 @@ private void TestSharingOfInitializerAndItsPrepackedVersion() } } - [Fact(DisplayName = "TestSharedAllocatorUsingCreateAndRegisterAllocator")] - private void TestSharedAllocatorUsingCreateAndRegisterAllocator() - { - var model = TestDataLoader.LoadModelFromEmbeddedResource("mul_1.onnx"); - - using (var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, - OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default)) - using (var arenaCfg = new OrtArenaCfg(0, -1, -1, -1)) - { - var env = OrtEnv.Instance(); - // Create and register the arena based allocator - env.CreateAndRegisterAllocator(memInfo, arenaCfg); - - using (var sessionOptions = new SessionOptions()) - { - // Key must match kOrtSessionOptionsConfigUseEnvAllocators in onnxruntime_session_options_config_keys.h - sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1"); - - // Create two sessions to share the allocator - // Create a third session that DOES NOT use the allocator in the environment - using (var session1 = new InferenceSession(model, sessionOptions)) - using (var session2 = new InferenceSession(model, sessionOptions)) - using (var session3 = new InferenceSession(model)) // Use the default SessionOptions instance - { - // Input data - var inputDims = new long[] { 3, 2 }; - var input = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F }; - - // Output data - int[] outputDims = { 3, 2 }; - float[] output = { 1.0F, 4.0F, 9.0F, 16.0F, 25.0F, 36.0F }; - - // Run inference on all three models - var inputMeta = session1.InputMetadata; - var container = new List(); - - foreach (var name in inputMeta.Keys) - { - Assert.Equal(typeof(float), inputMeta[name].ElementType); - Assert.True(inputMeta[name].IsTensor); - var tensor = new DenseTensor(input, inputMeta[name].Dimensions); - container.Add(NamedOnnxValue.CreateFromTensor(name, tensor)); - } - - // Run inference with named inputs and outputs created with in Run() - using (var results = session1.Run(container)) // results is an IReadOnlyList container - { - foreach (var r in results) - { - ValidateRunResultData(r.AsTensor(), output, outputDims); - } - } - - // Run inference with named inputs and outputs created with in Run() - using (var results = 
session2.Run(container)) // results is an IReadOnlyList container - { - foreach (var r in results) - { - ValidateRunResultData(r.AsTensor(), output, outputDims); - } - } - - // Run inference with named inputs and outputs created with in Run() - using (var results = session3.Run(container)) // results is an IReadOnlyList container - { - foreach (var r in results) - { - ValidateRunResultData(r.AsTensor(), output, outputDims); - } - } - } - } - } - } - internal static Tuple, float[]> OpenSessionSqueezeNet(int? deviceId = null) { var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs index 9368f9d8bc298..1be0b6e9530ed 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -60,6 +60,8 @@ public void GetEpDevices() Assert.NotNull(metadata); var options = ep_device.EpOptions; Assert.NotNull(options); + var memInfo = ep_device.GetMemoryInfo(OrtDeviceMemoryType.DEFAULT); + Assert.NotNull(memInfo); ReadHardwareDeviceValues(ep_device.HardwareDevice); } } @@ -77,14 +79,17 @@ public void RegisterUnregisterLibrary() // register. shouldn't throw ortEnvInstance.RegisterExecutionProviderLibrary(epName, libFullPath); - - // check OrtEpDevice was found - var epDevices = ortEnvInstance.GetEpDevices(); - var found = epDevices.Any(d => string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)); - Assert.True(found); - - // unregister - ortEnvInstance.UnregisterExecutionProviderLibrary(epName); + try + { + // check OrtEpDevice was found + var epDevices = ortEnvInstance.GetEpDevices(); + var found = epDevices.Any(d => string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)); + Assert.True(found); + } + finally + { // unregister + ortEnvInstance.UnregisterExecutionProviderLibrary(epName); + } } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs index 229d683c162fd..ae4fb0cf164cd 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs @@ -1,7 +1,17 @@ -using System; +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using Microsoft.ML.OnnxRuntime.Tensors; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; using Xunit; - namespace Microsoft.ML.OnnxRuntime.Tests { /// @@ -212,5 +222,275 @@ public void TestEnvWithCustomLoggerAndThredingOptions() } } } + + [Collection("Ort Inference Tests")] + public class OrtEnvSharedAllocatorsTests + { + private void ValidateRunResultData(Tensor resultTensor, float[] expectedOutput, int[] expectedDimensions) + { + Assert.Equal(expectedDimensions.Length, resultTensor.Rank); + + var resultDimensions = resultTensor.Dimensions; + for (int i = 0; i < expectedDimensions.Length; i++) + { + Assert.Equal(expectedDimensions[i], resultDimensions[i]); + } + + var resultArray = new float[resultTensor.Length]; + for (int i = 0; i < resultTensor.Length; i++) + { + resultArray[i] = resultTensor.GetValue(i); + } + Assert.Equal(expectedOutput.Length, resultArray.Length); + Assert.Equal(expectedOutput, resultArray, new FloatComparer()); + } + + [Fact(DisplayName = "TestSharedAllocatorUsingCreateAndRegisterAllocator")] + private void TestSharedAllocatorUsingCreateAndRegisterAllocator() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("mul_1.onnx"); + + using var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, + OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default); + using var arenaCfg = new OrtArenaCfg(0, -1, -1, -1); + var env = OrtEnv.Instance(); + // Create and register the arena based allocator + env.CreateAndRegisterAllocator(memInfo, arenaCfg); + try + { + using var sessionOptions = new SessionOptions(); + // Key must match kOrtSessionOptionsConfigUseEnvAllocators in onnxruntime_session_options_config_keys.h + sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1"); + + // Create two sessions to share the allocator + // Create a third session that DOES NOT use the allocator in the environment + using var session1 = new InferenceSession(model, sessionOptions); + using var session2 = new InferenceSession(model, sessionOptions); + using var session3 = new InferenceSession(model); // Use the default SessionOptions instance + // Input data + var inputDims = new long[] { 3, 2 }; + var input = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F }; + + // Output data + int[] outputDims = { 3, 2 }; + float[] output = { 1.0F, 4.0F, 9.0F, 16.0F, 25.0F, 36.0F }; + + // Run inference on all three models + var inputMeta = session1.InputMetadata; + var container = new List(); + + foreach (var name in inputMeta.Keys) + { + Assert.Equal(typeof(float), inputMeta[name].ElementType); + Assert.True(inputMeta[name].IsTensor); + var tensor = new DenseTensor(input, inputMeta[name].Dimensions); + container.Add(NamedOnnxValue.CreateFromTensor(name, tensor)); + } + + // Run inference with named inputs and outputs created with in Run() + using var results = session1.Run(container); // results is an IReadOnlyList container + foreach (var r in results) + { + ValidateRunResultData(r.AsTensor(), output, outputDims); + } + + // Run inference with named inputs and outputs created with in Run() + using var results2 = session2.Run(container); // results is an IReadOnlyList container + foreach (var r in results2) + { + ValidateRunResultData(r.AsTensor(), output, outputDims); + } + + // Run inference with named inputs and outputs created with in Run() + using var results3 = session3.Run(container); // 
results is an IReadOnlyList container + foreach (var r in results3) + { + ValidateRunResultData(r.AsTensor(), output, outputDims); + } + } + finally + { + // Unregister the allocator + env.UnregisterAllocator(memInfo); + } + } + + [Fact(DisplayName = "TestSharedAllocatorUsingCreateAndRegisterAllocatorV2")] + private void TestSharedAllocatorUsingCreateAndRegisterAllocatorV2() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("mul_1.onnx"); + + using var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, + OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default); + using var arenaCfg = new OrtArenaCfg(0, -1, -1, -1); + var env = OrtEnv.Instance(); + + // Fill in with two arbitrary key-value pairs + var options = new Dictionary() { + { "key1", "value1" }, + { "key2", "value2" } + }; + + // Simply execute CreateAndRegisterAllocatorV2 to verify that C# API works as expected + env.CreateAndRegisterAllocator("CPUExecutionProvider", memInfo, arenaCfg, options); + try + { + using var sessionOptions = new SessionOptions(); + // Key must match kOrtSessionOptionsConfigUseEnvAllocators in onnxruntime_session_options_config_keys.h + sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1"); + using var session = new InferenceSession(model, sessionOptions); + } + finally + { + // Unregister the allocator + env.UnregisterAllocator(memInfo); + } + } + [Fact(DisplayName = "TestCreateGetReleaseSharedAllocator")] + private void TestCreateGetReleaseSharedAllocator() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var env = OrtEnv.Instance(); + string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), "example_plugin_ep.dll"); + Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + + // example plugin ep uses the registration name as the ep name + const string epName = "csharp_ep"; + + env.RegisterExecutionProviderLibrary(epName, libFullPath); + try + { + // Find OrtEpDevice for the example EP + OrtEpDevice epDevice = null; + var epDevices = env.GetEpDevices(); + foreach (var d in epDevices) + { + if (string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)) + { + epDevice = d; + } + } + Assert.NotNull(epDevice); + + using var epMemoryInfo = epDevice.GetMemoryInfo(OrtDeviceMemoryType.DEFAULT); + + var options = new Dictionary() { + { "arena.initial_chunk_size_bytes", "25600" }, + }; + + // Strictly speaking the allocator is owned by the env + // but we want to dispose the C# object anyway + using var sharedAllocator = env.CreateSharedAllocator(epDevice, + OrtDeviceMemoryType.DEFAULT, + OrtAllocatorType.DeviceAllocator, + options); + + try + { + using var getAllocator = env.GetSharedAllocator(epMemoryInfo); + Assert.NotNull(getAllocator); + } + finally + { + // ReleaseSharedAllocator is a no-op if the allocator was created with CreateAndRegisterAllocator + env.ReleaseSharedAllocator(epDevice, OrtDeviceMemoryType.DEFAULT); + } + } + finally + { + env.UnregisterExecutionProviderLibrary(epName); + } + } + } + + [Fact(DisplayName = "TestCopyTensors")] + void TestCopyTensors() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var env = OrtEnv.Instance(); + string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), "example_plugin_ep.dll"); + Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + + // example plugin ep uses the registration name as the ep name + const string epName = "csharp_ep"; + + env.RegisterExecutionProviderLibrary(epName, 
libFullPath); + try + { + // Find the example device + OrtEpDevice epDevice = null; + var epDevices = env.GetEpDevices(); + foreach (var d in epDevices) + { + if (string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)) + { + epDevice = d; + } + } + Assert.NotNull(epDevice); + + using var syncStream = epDevice.CreateSyncStream(null); + Assert.NotNull(syncStream); + // This returned Zero for example EP + // therefore do not assert for zero. + var streamHandle = syncStream.GetHandle(); + // Assert.NotEqual(IntPtr.Zero, streamHandle); + + var inputDims = new long[] { 3, 2 }; + float[] inputData1 = [1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F]; + long[] inputData2 = [1, 2, 3, 4, 5, 6]; + + // Create source OrtValues on CPU on top of inputData + using var inputList = new DisposableListTest(2) + { + OrtValue.CreateTensorValueFromMemory(inputData1, inputDims), + OrtValue.CreateTensorValueFromMemory(inputData2, inputDims) + }; + + using var epMemoryInfo = epDevice.GetMemoryInfo(OrtDeviceMemoryType.DEFAULT); + var options = new Dictionary() { + { "arena.initial_chunk_size_bytes", "25600" }, + }; + + // Strictly speaking the allocator is owned by the env + // but we want to dispose the C# object anyway + using var sharedAllocator = env.CreateSharedAllocator(epDevice, + OrtDeviceMemoryType.DEFAULT, + OrtAllocatorType.DeviceAllocator, + options); + try + { + // Create destination empty OrtValues on the example EP device + using var outputList = new DisposableListTest(2) + { + OrtValue.CreateAllocatedTensorValue(sharedAllocator, + TensorElementType.Float, inputDims), + OrtValue.CreateAllocatedTensorValue(sharedAllocator, + TensorElementType.Int64, inputDims) + }; + + env.CopyTensors(inputList, outputList, syncStream); + + // Assert.Equal data on inputList and outputList + Assert.Equal(inputList[0].GetTensorDataAsSpan(), + outputList[0].GetTensorDataAsSpan()); + Assert.Equal(inputList[1].GetTensorDataAsSpan(), + outputList[1].GetTensorDataAsSpan()); + } + finally + { + // Unregister from the env + env.ReleaseSharedAllocator(epDevice, OrtDeviceMemoryType.DEFAULT); + } + } + finally + { + env.UnregisterExecutionProviderLibrary(epName); + } + } + } + } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index eab4a3d412898..89dbce05326b5 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -42,33 +42,38 @@ public partial class InferenceTest public void CanCreateAndDisposeSessionWithModelPath() { string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); - using (var session = new InferenceSession(modelPath)) + using var session = new InferenceSession(modelPath); + Assert.NotNull(session); + Assert.NotNull(session.InputMetadata); + Assert.Single(session.InputMetadata); // 1 input nodeMeta + Assert.True(session.InputMetadata.ContainsKey("data_0")); // input nodeMeta name + Assert.Equal(typeof(float), session.InputMetadata["data_0"].ElementType); + Assert.True(session.InputMetadata["data_0"].IsTensor); + var expectedInputDimensions = new int[] { 1, 3, 224, 224 }; + Assert.Equal(expectedInputDimensions.Length, session.InputMetadata["data_0"].Dimensions.Length); + for (int i = 0; i < expectedInputDimensions.Length; i++) { - Assert.NotNull(session); - Assert.NotNull(session.InputMetadata); - Assert.Single(session.InputMetadata); // 1 
input nodeMeta - Assert.True(session.InputMetadata.ContainsKey("data_0")); // input nodeMeta name - Assert.Equal(typeof(float), session.InputMetadata["data_0"].ElementType); - Assert.True(session.InputMetadata["data_0"].IsTensor); - var expectedInputDimensions = new int[] { 1, 3, 224, 224 }; - Assert.Equal(expectedInputDimensions.Length, session.InputMetadata["data_0"].Dimensions.Length); - for (int i = 0; i < expectedInputDimensions.Length; i++) - { - Assert.Equal(expectedInputDimensions[i], session.InputMetadata["data_0"].Dimensions[i]); - } + Assert.Equal(expectedInputDimensions[i], session.InputMetadata["data_0"].Dimensions[i]); + } - Assert.NotNull(session.OutputMetadata); - Assert.Single(session.OutputMetadata); // 1 output nodeMeta - Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output nodeMeta name - Assert.Equal(typeof(float), session.OutputMetadata["softmaxout_1"].ElementType); - Assert.True(session.OutputMetadata["softmaxout_1"].IsTensor); - var expectedOutputDimensions = new int[] { 1, 1000, 1, 1 }; - Assert.Equal(expectedOutputDimensions.Length, session.OutputMetadata["softmaxout_1"].Dimensions.Length); - for (int i = 0; i < expectedOutputDimensions.Length; i++) - { - Assert.Equal(expectedOutputDimensions[i], session.OutputMetadata["softmaxout_1"].Dimensions[i]); - } + Assert.NotNull(session.OutputMetadata); + Assert.Single(session.OutputMetadata); // 1 output nodeMeta + Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output nodeMeta name + Assert.Equal(typeof(float), session.OutputMetadata["softmaxout_1"].ElementType); + Assert.True(session.OutputMetadata["softmaxout_1"].IsTensor); + var expectedOutputDimensions = new int[] { 1, 1000, 1, 1 }; + Assert.Equal(expectedOutputDimensions.Length, session.OutputMetadata["softmaxout_1"].Dimensions.Length); + for (int i = 0; i < expectedOutputDimensions.Length; i++) + { + Assert.Equal(expectedOutputDimensions[i], session.OutputMetadata["softmaxout_1"].Dimensions[i]); } + + using var inputsMemoryInfos = session.GetMemoryInfosForInputs(); + Assert.Equal(session.InputNames.Count, inputsMemoryInfos.Count); + using var outputsMemoryInfos = session.GetMemoryInfosForOutputs(); + Assert.Equal(session.OutputNames.Count, outputsMemoryInfos.Count); + var inputsEpDevices = session.GetEpDeviceForInputs(); + Assert.Equal(session.InputNames.Count, inputsEpDevices.Count); } #if NET8_0_OR_GREATER @@ -154,7 +159,7 @@ public void InferenceSessionDisposedDotnetTensors() { Assert.Equal(typeof(float), inputMeta[name].ElementType); Assert.True(inputMeta[name].IsTensor); - var tensor = SystemNumericsTensors.Tensor.Create(inputData, inputMeta[name].Dimensions.Select(x => (nint) x).ToArray()); + var tensor = SystemNumericsTensors.Tensor.Create(inputData, inputMeta[name].Dimensions.Select(x => (nint)x).ToArray()); inputOrtValues.Add(new DisposableTestPair(name, OrtValue.CreateTensorValueFromSystemNumericsTensorObject(tensor))); } diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index a6d69198fadcc..8469d40a96df3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -253,6 +253,8 @@ Do not modify directly.* |||[9, 12]|**T** = tensor(float)| |||[1, 8]|**T** = tensor(float)| |MelWeightMatrix|*in* num_mel_bins:**T1**
*in* dft_length:**T1**<br> *in* sample_rate:**T1**<br> *in* lower_edge_hertz:**T2**<br> *in* upper_edge_hertz:**T2**<br> *out* output:**T3**|17+|**T1** = tensor(int32), tensor(int64)<br> **T2** = tensor(float)<br> **T3** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyFromHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyToHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Min|*in* data_0:**T**<br> *out* min:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[8, 11]|**T** = tensor(double), tensor(float)|
@@ -547,6 +549,7 @@
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)<br> **T3** = tensor(float), tensor(float16), tensor(uint8)<br> **T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
+|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br> **T2** = tensor(int32), tensor(uint32)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br>
**Tid** = tensor(int64)| diff --git a/docs/python/README.rst b/docs/python/README.rst index fdef200c1d0de..c23c194ed8132 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime >; /** \brief Wrapper around ::OrtSyncStream * */ -struct SyncStream : detail::Base { - explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used - explicit SyncStream(OrtSyncStream* p) : Base{p} {} ///< Take ownership of a pointer created by C API - void* GetHandle() const; ///< Wraps SyncStream_GetHandle + +namespace detail { +template +struct SyncStreamImpl : Base { + using B = Base; + using B::B; + // For some reason this is not a const method on the stream + void* GetHandle(); ///< Wraps SyncStream_GetHandle }; +} // namespace detail + +struct SyncStream : detail::SyncStreamImpl { + ///< Create an empty SyncStream object, must be assigned a valid one to be used + explicit SyncStream(std::nullptr_t) {} + ///< Take ownership of a pointer created by C API + explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl{p} {} +}; + +using UnownedSyncStream = detail::SyncStreamImpl>; namespace detail { template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 59979189eed0f..cb6448ad12a81 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -669,9 +669,12 @@ inline void KeyValuePairs::Remove(const char* key) { GetApi().RemoveKeyValuePair(this->p_, key); } -inline void* SyncStream::GetHandle() const { +namespace detail { +template +inline void* SyncStreamImpl::GetHandle() { return GetApi().SyncStream_GetHandle(this->p_); } +} // namespace detail namespace detail { template @@ -1582,11 +1585,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForInputs( auto num_inputs = GetInputCount(); std::vector mem_infos; - mem_infos.resize(num_inputs); + if (num_inputs > 0) { + mem_infos.resize(num_inputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, - reinterpret_cast(mem_infos.data()), - num_inputs)); + ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_inputs)); + } return mem_infos; } @@ -1598,11 +1603,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs auto num_outputs = GetOutputCount(); std::vector mem_infos; - mem_infos.resize(num_outputs); + if (num_outputs > 0) { + mem_infos.resize(num_outputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, - reinterpret_cast(mem_infos.data()), - num_outputs)); + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_outputs)); + } return mem_infos; } @@ -1631,12 +1638,12 @@ template inline std::vector ConstSessionImpl::GetEpDeviceForInputs() const { auto num_inputs = GetInputCount(); std::vector input_devices; - input_devices.resize(num_inputs); - - ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, - reinterpret_cast(input_devices.data()), - num_inputs)); - + if (num_inputs > 0) { + input_devices.resize(num_inputs); + ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, + reinterpret_cast(input_devices.data()), + num_inputs)); + } return input_devices; } diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h 
b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 7eb5f7659a365..64a434e2fe301 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -408,3 +408,10 @@ static const char* const kOrtSessionOptionsDisableModelCompile = "session.disabl // Note: UNSUPPORTED models always fail regardless of this setting. static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel = "session.fail_on_suboptimal_compiled_model"; + +// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME +// Meant to be used with SetEpDynamicOptions +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtEpDynamicOptionsQnnHtpPerformanceMode = "ep.dynamic.qnn_htp_performance_mode"; diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 994eb6f4300c1..f00f4ec8dee50 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.23.0'; +export const version = '1.23.1'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 706f8b46a3ad4..9ef468a229788 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/common/package.json b/js/common/package.json index a0eff9095e6d7..200aff42f8fca 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 994eb6f4300c1..f00f4ec8dee50 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
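Referring back to the kOrtEpDynamicOptionsQnnHtpPerformanceMode key added above: as the comment notes, it is not a regular session option and is meant to be applied at runtime through SetEpDynamicOptions. A hedged C++ sketch, assuming the SetEpDynamicOptions wrapper present in onnxruntime_cxx_api.h (verify the exact signature in your header version):

```cpp
#include <onnxruntime_cxx_api.h>

// Sketch: switch the QNN HTP performance mode on a live session.
// "burst" is one of the documented values; the initial mode is "default".
void SetHtpBurstMode(Ort::Session& session) {
  const char* keys[] = {"ep.dynamic.qnn_htp_performance_mode"};
  const char* values[] = {"burst"};
  session.SetEpDynamicOptions(keys, values, 1);
}
```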
-export const version = '1.23.0'; +export const version = '1.23.1'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index bd7e6cc1966c7..0a65eab39df70 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.23.0", + "version": "1.23.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.23.0", + "version": "1.23.1", "hasInstallScript": true, "license": "MIT", "os": [ @@ -30,7 +30,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/node/package.json b/js/node/package.json index 5520a48aa124a..1f29f2354b0d7 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -11,7 +11,7 @@ 6 ] }, - "version": "1.23.0", + "version": "1.23.1", "dependencies": { "adm-zip": "^0.5.16", "global-agent": "^3.0.0", diff --git a/js/node/script/install-metadata-versions.js b/js/node/script/install-metadata-versions.js index 3147f90904e7a..23df0b7ac96ed 100644 --- a/js/node/script/install-metadata-versions.js +++ b/js/node/script/install-metadata-versions.js @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -module.exports = { nuget: [{ feed: 'nuget', version: '1.23.0' }] }; +module.exports = { nuget: [{ feed: 'nuget', version: '1.23.1' }] }; diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 994eb6f4300c1..f00f4ec8dee50 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.23.0'; +export const version = '1.23.1'; diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index ec2147b2cc4ba..f681b9166da98 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-react-native", - "version": "1.23.0", + "version": "1.23.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "onnxruntime-react-native", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "dependencies": { "buffer": "^6.0.3", @@ -31,7 +31,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/react_native/package.json b/js/react_native/package.json index 7a5ee35bdb25a..a88f5cf267aed 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -37,7 +37,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.23.0", + "version": "1.23.1", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 994eb6f4300c1..f00f4ec8dee50 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.23.0'; +export const version = '1.23.1'; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index eabb198e97177..74776abb25bd5 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.23.0", + "version": "1.23.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "dependencies": { "flatbuffers": "^25.1.24", @@ -50,7 +50,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/web/package.json b/js/web/package.json index 425aa88035424..db20202b4f24e 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -7,7 +7,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.23.0", + "version": "1.23.1", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^25.1.24", diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 550502cf3bc48..ac25159802092 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -8,7 +8,7 @@ or the `Github project `_. """ -__version__ = "1.23.0" +__version__ = "1.23.1" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). @@ -31,14 +31,17 @@ OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 OrtCompileApiFlags, # noqa: F401 + OrtDeviceMemoryType, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 OrtExternalInitializerInfo, # noqa: F401 OrtHardwareDevice, # noqa: F401 OrtHardwareDeviceType, # noqa: F401 OrtMemoryInfo, # noqa: F401 + OrtMemoryInfoDeviceType, # noqa: F401 OrtMemType, # noqa: F401 OrtSparseFormat, # noqa: F401 + OrtSyncStream, # noqa: F401 RunOptions, # noqa: F401 SessionIOBinding, # noqa: F401 SessionOptions, # noqa: F401 @@ -78,6 +81,7 @@ OrtDevice, # noqa: F401 OrtValue, # noqa: F401 SparseTensor, # noqa: F401 + copy_tensors, # noqa: F401 ) # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 34410a5f42630..d959d11e3fd43 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -108,6 +108,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE); // ******** End: Quantization ******************* // #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -275,6 +276,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h index eae96c186d471..84580b310f6b3 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h @@ -6,7 +6,7 @@ 
#include "core/common/common.h" #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" -#include "contrib_ops/cpu/moe/moe_helper.h" +#include "moe_helper.h" #include namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc new file mode 100644 index 0000000000000..a23ea07ac1cb8 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc @@ -0,0 +1,611 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/moe/moe_cpu.h" +#include "contrib_ops/cpu/moe/moe_utils.h" +#include "contrib_ops/cpu/moe/moe_helper.h" +#include "core/framework/op_kernel.h" +#include "core/providers/common.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "core/util/math_cpuonly.h" +#include "core/mlas/inc/mlas.h" +#include "core/framework/float16.h" +#include "core/framework/allocator.h" +#include "core/platform/threadpool.h" +#include "core/common/narrow.h" + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +template +MoE::MoE(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), MoEBaseCPU(op_kernel_info) { + if (activation_type_ == ActivationType::SwiGLU && swiglu_fusion_ != 1) { + ORT_THROW("CPU MoE only supports interleaved SwiGLU format. Please set swiglu_fusion=1."); + } +} + +template +Status MoE::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc1_experts_bias = context->Input(3); + const Tensor* fc2_experts_weights = context->Input(4); + const Tensor* fc2_experts_bias = context->Input(5); + const Tensor* fc3_experts_weights = context->Input(6); + const Tensor* fc3_experts_bias = context->Input(7); + + // FC3 not supported + if (fc3_experts_weights != nullptr || fc3_experts_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "FC3 is not implemented for CPU MoE."); + } + + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias, nullptr, + fc2_experts_weights, fc2_experts_bias, nullptr, + fc3_experts_weights, fc3_experts_bias, nullptr, + 1, + activation_type_ == ActivationType::SwiGLU)); + + Tensor* output = context->Output(0, input->Shape()); + + return ComputeMoE(context, input, router_probs, fc1_experts_weights, fc1_experts_bias, + fc2_experts_weights, fc2_experts_bias, output); +} + +template +Status MoE::ComputeMoE(const OpKernelContext* context, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias, + Tensor* output) const { + const auto& input_shape = input->Shape(); + const auto& router_shape = router_probs->Shape(); + const auto& fc2_shape = fc2_experts_weights->Shape(); + + const int64_t num_tokens = input_shape.Size() / input_shape[input_shape.NumDimensions() - 1]; + const int64_t hidden_size = input_shape[input_shape.NumDimensions() - 1]; + const int64_t num_experts = router_shape[1]; + const int64_t inter_size = (fc2_shape[1] * fc2_shape[2]) / hidden_size; + const bool is_swiglu = activation_type_ == ActivationType::SwiGLU; + const int64_t fc1_output_size = is_swiglu ? 
(inter_size * 2) : inter_size; + + const T* input_data = input->Data(); + const T* router_data = router_probs->Data(); + const T* fc1_weights_data = fc1_experts_weights->Data(); + const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; + const T* fc2_weights_data = fc2_experts_weights->Data(); + const T* fc2_bias_data = fc2_experts_bias ? fc2_experts_bias->Data() : nullptr; + T* output_data = output->MutableData(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const T* input_data_to_use = input_data; + IAllocatorUniquePtr input_data_copy_ptr; + if (normalize_routing_weights_) { + input_data_copy_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + T* input_data_copy = input_data_copy_ptr.get(); + std::copy(input_data, input_data + (num_tokens * hidden_size), input_data_copy); + input_data_to_use = input_data_copy; + } + + std::fill_n(output_data, output->Shape().Size(), T{}); + + IAllocatorUniquePtr router_logits_float_buffer; + const float* router_logits_float = nullptr; + if constexpr (std::is_same_v) { + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + router_logits_float = router_logits_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_data), const_cast(router_logits_float), static_cast(num_tokens * num_experts)); + } else { + router_logits_float = reinterpret_cast(router_data); + } + + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + int* route_expert = route_expert_ptr.get(); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + float* route_scale = route_scale_ptr.get(); + + auto* tp = context->GetOperatorThreadPool(); + int num_routing_threads = 1; + if (tp != nullptr && num_tokens >= 1024) { + int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp); + num_routing_threads = std::min(narrow(num_tokens / 512), max_threads); + num_routing_threads = std::max(1, num_routing_threads); + } + + std::vector>> thread_local_expert_token_maps(num_routing_threads); + for (auto& map : thread_local_expert_token_maps) { + map.resize(static_cast(num_experts)); + for (auto& expert_map : map) { + expert_map.reserve(static_cast(std::max(static_cast(1), num_tokens / num_experts / num_routing_threads * 2))); + } + } + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { + auto work = concurrency::ThreadPool::PartitionWork(narrow(thread_id), num_routing_threads, static_cast(num_tokens)); + auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; + + std::vector> sorted_logits(static_cast(num_experts)); + std::vector full_softmax(static_cast(num_experts)); + + for (int64_t i = work.start; i < work.end; ++i) { + const float* logits = router_logits_float + i * num_experts; + + float max_logit = logits[0]; + for (int64_t j = 1; j < num_experts; ++j) { + max_logit = std::max(max_logit, logits[j]); + } + + float sum_exp = 0.0f; + for (int64_t j = 0; j < num_experts; ++j) { + full_softmax[static_cast(j)] = std::exp(logits[j] - max_logit); + sum_exp += full_softmax[static_cast(j)]; + } + + const float inv_sum_exp = 1.0f / sum_exp; + for (int64_t j = 0; j < num_experts; ++j) { + full_softmax[static_cast(j)] *= inv_sum_exp; + sorted_logits[static_cast(j)] = {full_softmax[static_cast(j)], j}; + } + + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() 
+ static_cast(k_), sorted_logits.end(), std::greater<>()); + + if (normalize_routing_weights_) { + float top_k_sum = 0.0f; + for (int64_t j = 0; j < k_; ++j) { + top_k_sum += sorted_logits[static_cast(j)].first; + } + const float inv_top_k_sum = 1.0f / top_k_sum; + + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + float normalized_weight = sorted_logits[static_cast(j)].first * inv_top_k_sum; + + route_expert[route_idx] = narrow(expert_idx); + route_scale[route_idx] = normalized_weight; + if (normalized_weight > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } else { + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + float weight = sorted_logits[static_cast(j)].first; + + route_expert[route_idx] = narrow(expert_idx); + route_scale[route_idx] = weight; + if (weight > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } + } + }); + + std::vector> expert_token_map(static_cast(num_experts)); + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + size_t total_tokens_for_expert = 0; + for (int t = 0; t < num_routing_threads; ++t) { + total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + } + expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + } + + for (int t = 0; t < num_routing_threads; ++t) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + if (!local_tokens.empty()) { + auto& expert_map = expert_token_map[static_cast(expert_idx)]; + expert_map.insert(expert_map.end(), + std::make_move_iterator(local_tokens.begin()), + std::make_move_iterator(local_tokens.end())); + } + } + } + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& expert_map = expert_token_map[static_cast(expert_idx)]; + if (!expert_map.empty()) { + std::sort(expert_map.begin(), expert_map.end()); + } + } + + IAllocatorUniquePtr input_float_buffer; + const float* input_float; + if constexpr (std::is_same_v) { + input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + input_float = input_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data_to_use), const_cast(input_float), static_cast(num_tokens * hidden_size)); + } else { + input_float = reinterpret_cast(input_data_to_use); + } + + int num_expert_threads = 1; + if (tp != nullptr) { + int total_active_experts = 0; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (!expert_token_map[static_cast(expert_idx)].empty()) { + total_active_experts++; + } + } + + if (total_active_experts > 0) { + int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp); + num_expert_threads = std::min(total_active_experts, max_threads); + num_expert_threads = std::min(num_expert_threads, 8); + } + } + + // Calculate maximum possible tokens per expert for buffer sizing + int64_t max_tokens_per_expert = 0; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + max_tokens_per_expert = std::max(max_tokens_per_expert, static_cast(routes.size())); + } + + // Thread-local buffer pool for expert processing + struct ThreadLocalBuffers { + IAllocatorUniquePtr 
A1_buffer; + IAllocatorUniquePtr batch_weights_buffer; + IAllocatorUniquePtr token_ids_buffer; + IAllocatorUniquePtr A1_t_buffer; + IAllocatorUniquePtr C2_buffer; + // Additional buffers for ProcessExpertBatch to avoid repeated allocations + IAllocatorUniquePtr fc1_output_buffer; + IAllocatorUniquePtr activation_output_buffer; + int64_t current_capacity = 0; + int64_t current_fc1_capacity = 0; + int64_t current_activation_capacity = 0; + + void EnsureCapacity(AllocatorPtr& allocator, int64_t required_tokens, int64_t hidden_size, + int64_t fc1_output_size, int64_t inter_size) { + if (required_tokens > current_capacity) { + // Use high watermark approach - allocate more than needed for future reuse + int64_t new_capacity = std::max(required_tokens * 2, current_capacity + 512); + + A1_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + batch_weights_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity)); + token_ids_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity)); + A1_t_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + C2_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + + current_capacity = new_capacity; + } + + // Ensure ProcessExpertBatch buffers have sufficient capacity + int64_t required_fc1_capacity = required_tokens * fc1_output_size; + int64_t required_activation_capacity = required_tokens * inter_size; + + if (required_fc1_capacity > current_fc1_capacity) { + int64_t new_fc1_capacity = std::max(required_fc1_capacity * 2, current_fc1_capacity + (512 * fc1_output_size)); + fc1_output_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_fc1_capacity)); + current_fc1_capacity = new_fc1_capacity; + } + + if (required_activation_capacity > current_activation_capacity) { + int64_t new_activation_capacity = std::max(required_activation_capacity * 2, current_activation_capacity + (512 * inter_size)); + activation_output_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_activation_capacity)); + current_activation_capacity = new_activation_capacity; + } + } + }; + + // Pre-allocate thread-local buffer pools + std::vector thread_buffers(num_expert_threads); + for (int i = 0; i < num_expert_threads; ++i) { + thread_buffers[i].EnsureCapacity(allocator, max_tokens_per_expert, hidden_size, fc1_output_size, inter_size); + } + + const size_t output_buffer_size = static_cast(output->Shape().Size()); + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + float* thread_local_outputs = thread_local_outputs_ptr.get(); + + // Initialize thread-local outputs with vectorized operation + std::fill_n(thread_local_outputs, static_cast(num_expert_threads) * output_buffer_size, 0.0f); + + // Optimized expert processing with thread-local buffer reuse + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { + int thread_id = narrow(thread_id_pd); + auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + + float* local_output = thread_local_outputs + static_cast(thread_id) * output_buffer_size; + ThreadLocalBuffers& buffers = thread_buffers[thread_id]; + + for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + if (routes.empty()) continue; + + const int64_t num_expert_tokens 
= static_cast(routes.size()); + + // Ensure thread-local buffers have sufficient capacity + buffers.EnsureCapacity(allocator, num_expert_tokens, hidden_size, fc1_output_size, inter_size); + + // Use pre-allocated buffers from thread-local pool + float* A1 = buffers.A1_buffer.get(); + float* batch_weights = buffers.batch_weights_buffer.get(); + int64_t* token_ids = buffers.token_ids_buffer.get(); + T* A1_t = buffers.A1_t_buffer.get(); + T* C2 = buffers.C2_buffer.get(); + T* fc1_output = buffers.fc1_output_buffer.get(); + T* activation_output = buffers.activation_output_buffer.get(); + + // Optimized data gathering with better memory access patterns + for (int64_t r = 0; r < num_expert_tokens; ++r) { + int64_t route_idx = routes[static_cast(r)]; + int64_t token = route_idx / k_; + + token_ids[r] = token; + batch_weights[r] = route_scale[route_idx]; + + // Use SIMD-friendly copy for better performance + const float* src = input_float + token * hidden_size; + float* dst = A1 + static_cast(r) * static_cast(hidden_size); + std::copy(src, src + hidden_size, dst); + } + + const T* fc1_expert_weights = fc1_weights_data + expert_idx * fc1_output_size * hidden_size; + const T* fc1_expert_bias = fc1_bias_data ? fc1_bias_data + expert_idx * fc1_output_size : nullptr; + const T* fc2_expert_weights = fc2_weights_data + expert_idx * hidden_size * inter_size; + const T* fc2_expert_bias = fc2_bias_data ? fc2_bias_data + expert_idx * hidden_size : nullptr; + + // Convert input to T only when needed for computation + for (size_t i = 0; i < static_cast(num_expert_tokens * hidden_size); ++i) { + A1_t[i] = static_cast(A1[i]); + } + + ORT_IGNORE_RETURN_VALUE(ProcessExpertBatch(A1_t, token_ids, batch_weights, + num_expert_tokens, expert_idx, + fc1_expert_weights, fc1_expert_bias, + fc2_expert_weights, fc2_expert_bias, + C2, hidden_size, inter_size, + fc1_output, activation_output)); + + // Optimized output accumulation with vectorized operations + for (int64_t r = 0; r < num_expert_tokens; ++r) { + int64_t token = token_ids[r]; + const T* expert_output_t = C2 + static_cast(r) * static_cast(hidden_size); + float w = batch_weights[r]; + float* dest = local_output + static_cast(token) * static_cast(hidden_size); + + // Use explicit loop for better vectorization opportunities + for (int64_t j = 0; j < hidden_size; ++j) { + dest[j] += w * static_cast(expert_output_t[j]); + } + } + } + }); + + auto accumulate = [&](float* buffer) { + std::fill_n(buffer, output_buffer_size, 0.0f); + + for (size_t j = 0; j < output_buffer_size; ++j) { + double sum = 0.0; + double c = 0.0; + + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + double y = static_cast(thread_local_outputs[thread_offset + j]) - c; + double t = sum + y; + c = (t - sum) - y; + sum = t; + } + buffer[j] = static_cast(sum); + } + }; + + if constexpr (std::is_same_v) { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + MlasConvertFloatToHalfBuffer(final_output_float, + reinterpret_cast(output->MutableData()), + static_cast(output_buffer_size)); + } else { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + float* out_ptr = reinterpret_cast(output->MutableData()); + memcpy(out_ptr, final_output_float, output_buffer_size * 
sizeof(float)); + } + return Status::OK(); +} +template +Status MoE::ProcessExpertBatch(const T* input_tokens, + const int64_t* token_expert_ids, + const float* token_weights, + int64_t batch_size, + int64_t expert_id, + const T* fc1_weights, + const T* fc1_bias, + const T* fc2_weights, + const T* fc2_bias, + T* output_buffer, + int64_t hidden_size, + int64_t inter_size, + T* fc1_output_buffer, + T* activation_output_buffer) const { + ORT_UNUSED_PARAMETER(token_expert_ids); + ORT_UNUSED_PARAMETER(token_weights); + ORT_UNUSED_PARAMETER(expert_id); + ORT_UNUSED_PARAMETER(fc1_output_buffer); + ORT_UNUSED_PARAMETER(activation_output_buffer); + const bool is_swiglu = activation_type_ == ActivationType::SwiGLU; + const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size; + + constexpr int64_t stack_threshold = 1024; + const bool use_stack = (batch_size * fc1_output_size) <= stack_threshold; + + std::vector fc1_output_vec; + std::vector activation_output_vec; + T* fc1_output; + T* activation_output; + + if (use_stack) { + fc1_output_vec.resize(static_cast(batch_size * fc1_output_size)); + activation_output_vec.resize(static_cast(batch_size * inter_size)); + fc1_output = fc1_output_vec.data(); + activation_output = activation_output_vec.data(); + } else { + fc1_output_vec.resize(static_cast(batch_size * fc1_output_size)); + activation_output_vec.resize(static_cast(batch_size * inter_size)); + fc1_output = fc1_output_vec.data(); + activation_output = activation_output_vec.data(); + } + + ORT_RETURN_IF_ERROR(ComputeGEMM(input_tokens, fc1_weights, fc1_output, + batch_size, hidden_size, fc1_output_size, true)); + + if (fc1_bias) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + T* batch_output = fc1_output + batch * fc1_output_size; + // Explicit loop for better vectorization + for (int64_t i = 0; i < fc1_output_size; ++i) { + batch_output[i] = static_cast(static_cast(batch_output[i]) + + static_cast(fc1_bias[i])); + } + } + } + + if (is_swiglu) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + ApplySwiGLUVectorized(fc1_output + batch * fc1_output_size, + activation_output + batch * inter_size, + inter_size); + } + } else { + ApplyActivationVectorized(fc1_output, batch_size * fc1_output_size); + std::copy(fc1_output, fc1_output + (batch_size * fc1_output_size), activation_output); + } + + ORT_RETURN_IF_ERROR(ComputeGEMM(activation_output, fc2_weights, output_buffer, + batch_size, inter_size, hidden_size, true)); + + if (fc2_bias) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + T* batch_output = output_buffer + batch * hidden_size; + for (int64_t i = 0; i < hidden_size; ++i) { + batch_output[i] = static_cast(static_cast(batch_output[i]) + + static_cast(fc2_bias[i])); + } + } + } + + return Status::OK(); +} + +template <> +Status MoE::ComputeGEMM(const float* A, const float* B, float* C, + int64_t M, int64_t K, int64_t N, bool transpose_B) const { + MLAS_SGEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.alpha = 1.0f; + params.beta = 0.0f; + params.C = C; + params.ldc = static_cast(N); + params.B = B; + + if (transpose_B) { + params.ldb = static_cast(K); + MlasGemm(CblasNoTrans, CblasTrans, static_cast(M), static_cast(N), static_cast(K), params, nullptr); + } else { + params.ldb = static_cast(N); + MlasGemm(CblasNoTrans, CblasNoTrans, static_cast(M), static_cast(N), static_cast(K), params, nullptr); + } + + return Status::OK(); +} + +template <> +Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C, 
+ int64_t M, int64_t K, int64_t N, bool transpose_B) const { + MLAS_HALF_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.C = C; + params.ldc = static_cast(N); + params.AIsfp32 = false; + params.BIsfp32 = false; + params.B = B; + + if (transpose_B) { + params.ldb = static_cast(K); + } else { + params.ldb = static_cast(N); + } + + MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, nullptr); + return Status::OK(); +} + +template +void MoE::ApplyActivationVectorized(T* data, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + float val = static_cast(data[i]); + data[i] = static_cast(ApplyActivation(val, activation_type_)); + } +} + +template +void MoE::ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + float gate = static_cast(input[2 * i]); + float linear = static_cast(input[2 * i + 1]); + + gate = std::min(gate, swiglu_limit_); + linear = std::clamp(linear, -swiglu_limit_, swiglu_limit_); + + float sigmoid_arg = activation_alpha_ * gate; + float sigmoid_out; + if (sigmoid_arg > 0) { + float exp_neg = std::exp(-sigmoid_arg); + sigmoid_out = 1.0f / (1.0f + exp_neg); + } else { + float exp_pos = std::exp(sigmoid_arg); + sigmoid_out = exp_pos / (1.0f + exp_pos); + } + + float swish_out = gate * sigmoid_out; + output[i] = static_cast(swish_out * (linear + activation_beta_)); + } +} + +template <> +void MoE::ApplySwiGLUVectorized(const float* input, float* output, int64_t size) const { + ApplySwiGLUActivation(input, output, size, true, + activation_alpha_, activation_beta_, swiglu_limit_); +} + +template class MoE; +template class MoE; + +#define REGISTER_KERNEL_TYPED(type) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MoE, kMSDomain, 1, type, kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MoE); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h new file mode 100644 index 0000000000000..60d8217015b5b --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
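For reference, ApplySwiGLUVectorized above consumes interleaved gate/linear pairs and computes output[i] = gate * sigmoid(activation_alpha * gate) * (linear + activation_beta), with gate clamped from above by swiglu_limit_ and linear clamped to [-limit, limit]. A standalone restatement of that math (hypothetical free function for illustration; the kernel's member implementation is authoritative):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Interleaved SwiGLU as used by the CPU MoE kernel:
// gate = input[2i], linear = input[2i+1].
void SwiGluInterleaved(const float* input, float* output, int64_t size,
                       float alpha, float beta, float limit) {
  for (int64_t i = 0; i < size; ++i) {
    const float gate = std::min(input[2 * i], limit);                  // clamp from above only
    const float linear = std::clamp(input[2 * i + 1], -limit, limit);  // symmetric clamp
    const float x = alpha * gate;
    // Keep the exp() argument non-positive for numerical stability.
    const float sig = x > 0.0f ? 1.0f / (1.0f + std::exp(-x))
                               : std::exp(x) / (1.0f + std::exp(x));
    output[i] = gate * sig * (linear + beta);
  }
}
```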
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +template +class MoE final : public OpKernel, public MoEBaseCPU { + public: + explicit MoE(const OpKernelInfo& op_kernel_info); + Status Compute(OpKernelContext* context) const override; + + private: + Status ComputeMoE(const OpKernelContext* context, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias, + Tensor* output) const; + + Status ProcessExpertBatch(const T* input_tokens, + const int64_t* token_expert_ids, + const float* token_weights, + int64_t num_tokens, + int64_t expert_id, + const T* fc1_weights, + const T* fc1_bias, + const T* fc2_weights, + const T* fc2_bias, + T* output_buffer, + int64_t hidden_size, + int64_t inter_size, + T* fc1_output_buffer, + T* activation_output_buffer) const; + + Status ComputeGEMM(const T* A, const T* B, T* C, + int64_t M, int64_t K, int64_t N, + bool transpose_B = false) const; + + void ApplyActivationVectorized(T* data, int64_t size) const; + void ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index e494719464d20..39249f842e632 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -49,7 +49,8 @@ Status CheckInputs(MoEParameters& parameters, const Tensor* fc3_experts_bias, // optional const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) - const bool is_fused_swiglu) { + const bool is_fused_swiglu, + const int64_t block_size = 0) { // block size for block-wise quantization // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. 
ASSERT_TENSOR_2D_OR_3D(input); ASSERT_TENSOR_3D(fc1_experts_weights); @@ -90,9 +91,63 @@ Status CheckInputs(MoEParameters& parameters, CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size); CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size); - CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size); - CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size); - CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size); + // Validate scale tensors: Handle both row-wise and block-wise quantization flexibly + // First, detect the actual quantization method from the tensor shapes + bool is_row_wise_quantization = true; + if (fc1_experts_scales != nullptr) { + const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims(); + if (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1) { + is_row_wise_quantization = false; + } + } + + if (block_size > 0 && !is_row_wise_quantization) { + // Block-wise quantization: 3D scale tensors + // For block-wise quantization, we calculate the number of blocks using ceiling division + // to handle cases where the dimension is not perfectly divisible by block_size + const int64_t fc1_blocks_per_row = (hidden_size + block_size - 1) / block_size; + const int64_t fc2_blocks_per_row = (inter_size + block_size - 1) / block_size; + const int64_t fc3_blocks_per_row = (hidden_size + block_size - 1) / block_size; + + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, fc1_blocks_per_row); + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, fc2_blocks_per_row); + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, fc3_blocks_per_row); + } else { + // Row-wise quantization: 2D scale tensors or 3D with last dimension = 1 + // Handle both {num_experts, features} and {num_experts, features, 1} shapes + if (fc1_experts_scales != nullptr) { + const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims(); + if (fc1_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size); + } else if (fc1_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, 1); + } else { + ORT_THROW("fc1_experts_scales must be 2D or 3D tensor"); + } + } + + if (fc2_experts_scales != nullptr) { + const auto& fc2_scales_dims = fc2_experts_scales->Shape().GetDims(); + if (fc2_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size); + } else if (fc2_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, 1); + } else { + ORT_THROW("fc2_experts_scales must be 2D or 3D tensor"); + } + } + + if (fc3_experts_scales != nullptr) { + const auto& fc3_scales_dims = fc3_experts_scales->Shape().GetDims(); + if (fc3_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size); + } else if (fc3_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, 1); + } else { + ORT_THROW("fc3_experts_scales must be 2D or 3D tensor"); + } + } + } if (fc3_experts_weights == nullptr) { ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr); diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 5c6c3b919b572..8a3c3f6d9f37a 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -2,12 +2,17 @@ // Licensed under the MIT License. 
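The block-wise scale validation above reduces to one rule: a weight row of length row_len (hidden_size for fc1/fc3, inter_size for fc2) is covered by ceil(row_len / block_size) scale entries, computed via ceiling division. A minimal sketch of just that computation:

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division used for the block-wise scale shapes: rows whose length
// is not a multiple of block_size still need a final partial block.
int64_t BlocksPerRow(int64_t row_len, int64_t block_size) {
  return (row_len + block_size - 1) / block_size;
}

int main() {
  assert(BlocksPerRow(4096, 64) == 64);  // evenly divisible
  assert(BlocksPerRow(4100, 64) == 65);  // partial trailing block
  return 0;
}
```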
#include "contrib_ops/cpu/moe/moe_quantization_cpu.h" - #include "core/framework/allocator.h" #include "core/framework/float16.h" #include "core/mlas/inc/mlas.h" +#include "core/mlas/inc/mlas_q4.h" #include "core/platform/threadpool.h" #include "core/providers/cpu/math/gemm_helper.h" +#include "core/providers/cpu/activation/activations.h" +#include "core/common/safeint.h" +#include "core/common/narrow.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/util/math.h" #include "contrib_ops/cpu/moe/moe_utils.h" #include "contrib_ops/cpu/moe/moe_helper.h" @@ -17,44 +22,326 @@ #include #include +namespace { +inline int64_t GetOptimalBlockSize(int64_t total_elements, int num_threads) { + if (total_elements <= 0 || num_threads <= 0) return 64; + const int64_t l1_cache_elements = 8192; // ~32KB / 4 bytes per float + const int64_t divisor = std::max(1, num_threads > 1 ? 4 : 2); + const int64_t base_block_size = l1_cache_elements / divisor; + const int64_t max_block = std::max(int64_t{32}, total_elements / std::max(int64_t{1}, int64_t{4})); + return std::clamp(base_block_size, int64_t{32}, std::min(int64_t{512}, max_block)); +} + +inline int64_t GetUnrollFactor(int64_t vector_size) { + if (vector_size <= 0) return 2; + if (vector_size >= 512) return 16; + if (vector_size >= 128) return 8; + if (vector_size >= 32) return 4; + return 2; +} + +inline bool ShouldUseMemcpy(int64_t size) { + return size >= 64; +} + +inline int64_t GetDequantBlockSize(int64_t features, int64_t total_work) { + if (features <= 0 || total_work <= 0) return 16; + const int64_t target_block_size = std::max(int64_t{16}, features / std::max(int64_t{1}, int64_t{8})); + const int64_t work_based_size = std::max(int64_t{16}, total_work / std::max(int64_t{1}, int64_t{4})); + return std::min(target_block_size, work_based_size); +} + +bool CanUseMlasQ4Dequant(int64_t num_bits) { + if (num_bits != 4) { + return false; + } + + return true; +} + +bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, + int64_t rows, int64_t cols, MLAS_BLK_QUANT_TYPE& out_qtype) { + if (expert_weight_bits != 4) { + return false; + } + + if (block_size == 64) { + out_qtype = BlkQ4Sym64; + } else if (block_size == 128) { + out_qtype = BlkQ4Sym128; + } else if (block_size == 0) { + out_qtype = BlkQ4Sym; + } else { + return false; + } + + size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(cols), static_cast(rows)); + return expected_size > 0; +} + +} // namespace + namespace onnxruntime { namespace contrib { -// Helper function to dequantize weights. Supports 4-bit and 8-bit symmetric quantization. -// The source quantized weights are stored as a row-major representation of the transposed -// logical weight matrix (W^T). This function dequantizes it into a float row-major W^T matrix. 
 template <typename TScale>
-void DequantizeBlock(const uint8_t* quantized_data,
-                     const TScale* scales,
-                     int64_t /*block_size*/,
-                     int64_t num_bits,
-                     int64_t rows,
-                     int64_t cols,
-                     float* dequantized_data) {
+void DequantizeBlockWithMlas(const uint8_t* quantized_data,
+                             const TScale* scales,
+                             int64_t block_size,
+                             int64_t num_bits,
+                             int64_t rows,
+                             int64_t cols,
+                             float* dequantized_data,
+                             MLAS_THREADPOOL* thread_pool);
+
+template <typename TScale>
+Status ConvertToMlasQ4Format(const uint8_t* quantized_data,
+                             const TScale* scales,
+                             int64_t block_size,
+                             int64_t num_bits,
+                             int64_t rows,
+                             int64_t cols,
+                             MLAS_BLK_QUANT_TYPE qtype,
+                             AllocatorPtr allocator,
+                             IAllocatorUniquePtr<uint8_t>& mlas_packed_buffer) {
+  if (num_bits != 4) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only 4-bit quantization supported for MLAS Q4 format conversion");
+  }
+
+  auto temp_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(rows * cols));
+  float* temp_float = temp_float_buffer.get();
+
+  DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, temp_float, nullptr);
+
+  size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(cols), static_cast<size_t>(rows));
+  if (packed_size == 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration");
+  }
+
+  mlas_packed_buffer = IAllocator::MakeUniquePtr<uint8_t>(allocator, packed_size);
+  MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast<size_t>(cols), static_cast<size_t>(rows), static_cast<size_t>(cols));
+
+  return Status::OK();
+}
+
+Status DirectQ4Gemm(const float* A,
+                    const uint8_t* mlas_packed_B,
+                    const float* bias,
+                    float* C,
+                    int64_t M,
+                    int64_t N,
+                    int64_t K,
+                    MLAS_BLK_QUANT_TYPE qtype,
+                    MLAS_THREADPOOL* thread_pool) {
+  MLAS_Q4_GEMM_DATA_PARAMS params;
+  params.A = A;
+  params.lda = static_cast<size_t>(K);
+  params.B = mlas_packed_B;
+  params.Bias = bias;
+  params.C = C;
+  params.ldc = static_cast<size_t>(N);
+  params.OutputProcessor = nullptr;
+
+  MlasQ4GemmBatch(qtype, static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &params, thread_pool);
+  return Status::OK();
+}
+
+template <typename TScale>
+void DequantizeBlockWithMlas(const uint8_t* quantized_data,
+                             const TScale* scales,
+                             int64_t block_size,
+                             int64_t num_bits,
+                             int64_t rows,
+                             int64_t cols,
+                             float* dequantized_data,
+                             MLAS_THREADPOOL* thread_pool) {
+  ORT_UNUSED_PARAMETER(thread_pool);
   const float zero_point = num_bits == 8 ? 128.0f : 8.0f;
-  if (num_bits == 8) {
-    for (int64_t r = 0; r < rows; ++r) {
-      const float scale = static_cast<float>(scales[r]);
-      for (int64_t c = 0; c < cols; ++c) {
-        // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point)
-        dequantized_data[r * cols + c] = scale * (static_cast<float>(quantized_data[r * cols + c]) - zero_point);
+  const int64_t blocks_per_row = (block_size > 0) ? ((cols + block_size - 1) / block_size) : 1;
+
+  if (CanUseMlasQ4Dequant(num_bits)) {
+    const int64_t packed_cols = (cols + 1) / 2;
+
+    if (block_size == 0) {
+      for (int64_t r = 0; r < rows; ++r) {
+        const uint8_t* row_data = quantized_data + r * packed_cols;
+        float* row_output = dequantized_data + r * cols;
+        const float scale = static_cast<float>(scales[r]);
+
+        int64_t c = 0;
+        for (; c + 8 <= cols; c += 8) {
+          const uint8_t packed_val0 = row_data[(c + 0) / 2];
+          const uint8_t packed_val1 = row_data[(c + 2) / 2];
+          const uint8_t packed_val2 = row_data[(c + 4) / 2];
+          const uint8_t packed_val3 = row_data[(c + 6) / 2];
+
+          row_output[c + 0] = scale * (static_cast<float>(packed_val0 & 0x0F) - zero_point);
+          row_output[c + 1] = scale * (static_cast<float>(packed_val0 >> 4) - zero_point);
+          row_output[c + 2] = scale * (static_cast<float>(packed_val1 & 0x0F) - zero_point);
+          row_output[c + 3] = scale * (static_cast<float>(packed_val1 >> 4) - zero_point);
+          row_output[c + 4] = scale * (static_cast<float>(packed_val2 & 0x0F) - zero_point);
+          row_output[c + 5] = scale * (static_cast<float>(packed_val2 >> 4) - zero_point);
+          row_output[c + 6] = scale * (static_cast<float>(packed_val3 & 0x0F) - zero_point);
+          row_output[c + 7] = scale * (static_cast<float>(packed_val3 >> 4) - zero_point);
+        }
+
+        for (; c < cols; c += 2) {
+          const uint8_t packed_val = row_data[c / 2];
+          const uint8_t val0 = packed_val & 0x0F;
+          const uint8_t val1 = packed_val >> 4;
+
+          row_output[c] = scale * (static_cast<float>(val0) - zero_point);
+          if (c + 1 < cols) {
+            row_output[c + 1] = scale * (static_cast<float>(val1) - zero_point);
+          }
+        }
+      }
+      return;
+    } else {
+      for (int64_t r = 0; r < rows; ++r) {
+        const uint8_t* row_data = quantized_data + r * packed_cols;
+        float* row_output = dequantized_data + r * cols;
+
+        for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
+          const int64_t block_end = std::min(block_start + block_size, cols);
+          const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
+          const int64_t scale_idx = r * blocks_per_row + block_idx;
+          const float scale = static_cast<float>(scales[scale_idx]);
+
+          int64_t c = block_start;
+          for (; c + 4 <= block_end; c += 4) {
+            const uint8_t packed_val0 = row_data[(c + 0) / 2];
+            const uint8_t packed_val1 = row_data[(c + 2) / 2];
+
+            row_output[c + 0] = scale * (static_cast<float>(packed_val0 & 0x0F) - zero_point);
+            row_output[c + 1] = scale * (static_cast<float>(packed_val0 >> 4) - zero_point);
+            row_output[c + 2] = scale * (static_cast<float>(packed_val1 & 0x0F) - zero_point);
+            row_output[c + 3] = scale * (static_cast<float>(packed_val1 >> 4) - zero_point);
+          }
+
+          for (; c < block_end; c += 2) {
+            const uint8_t packed_val = row_data[c / 2];
+            const uint8_t val0 = packed_val & 0x0F;
+            const uint8_t val1 = packed_val >> 4;
+
+            row_output[c] = scale * (static_cast<float>(val0) - zero_point);
+            if (c + 1 < block_end) {
+              row_output[c + 1] = scale * (static_cast<float>(val1) - zero_point);
+            }
+          }
+        }
       }
+      return;
     }
-  }
-  } else if (num_bits == 4) {
-    const int64_t packed_cols = (cols + 1) / 2;
+  }
+
+  if (num_bits == 8 && block_size == 0) {
     for (int64_t r = 0; r < rows; ++r) {
       const float scale = static_cast<float>(scales[r]);
-      for (int64_t c = 0; c < cols; ++c) {
-        const uint8_t packed_val = quantized_data[r * packed_cols + c / 2];
-        // Unpack the 4-bit value. Low nibble for even columns, high nibble for odd columns.
-        const uint8_t quantized_val = (c % 2 == 0) ? (packed_val & 0x0F) : (packed_val >> 4);
-        // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point)
-        dequantized_data[r * cols + c] = scale * (static_cast<float>(quantized_val) - zero_point);
+      const uint8_t zero_pt = static_cast<uint8_t>(zero_point);
+
+      MlasDequantizeLinear(
+          quantized_data + r * cols,
+          dequantized_data + r * cols,
+          static_cast<size_t>(cols),
+          scale,
+          zero_pt);
+    }
+  } else {
+    if (num_bits == 8) {
+      for (int64_t r = 0; r < rows; ++r) {
+        const uint8_t* row_data = quantized_data + r * cols;
+        float* row_output = dequantized_data + r * cols;
+
+        int64_t c = 0;
+        if (block_size > 0) {
+          for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
+            const int64_t block_end = std::min(block_start + block_size, cols);
+            const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
+            const int64_t scale_idx = r * blocks_per_row + block_idx;
+            const float scale = static_cast<float>(scales[scale_idx]);
+
+            for (c = block_start; c + 4 <= block_end; c += 4) {
+              row_output[c] = scale * (static_cast<float>(row_data[c]) - zero_point);
+              row_output[c + 1] = scale * (static_cast<float>(row_data[c + 1]) - zero_point);
+              row_output[c + 2] = scale * (static_cast<float>(row_data[c + 2]) - zero_point);
+              row_output[c + 3] = scale * (static_cast<float>(row_data[c + 3]) - zero_point);
+            }
+            for (; c < block_end; ++c) {
+              row_output[c] = scale * (static_cast<float>(row_data[c]) - zero_point);
+            }
+          }
+        } else {
+          const float scale = static_cast<float>(scales[r]);
+          for (; c + 8 <= cols; c += 8) {
+            row_output[c] = scale * (static_cast<float>(row_data[c]) - zero_point);
+            row_output[c + 1] = scale * (static_cast<float>(row_data[c + 1]) - zero_point);
+            row_output[c + 2] = scale * (static_cast<float>(row_data[c + 2]) - zero_point);
+            row_output[c + 3] = scale * (static_cast<float>(row_data[c + 3]) - zero_point);
+            row_output[c + 4] = scale * (static_cast<float>(row_data[c + 4]) - zero_point);
+            row_output[c + 5] = scale * (static_cast<float>(row_data[c + 5]) - zero_point);
+            row_output[c + 6] = scale * (static_cast<float>(row_data[c + 6]) - zero_point);
+            row_output[c + 7] = scale * (static_cast<float>(row_data[c + 7]) - zero_point);
+          }
+          for (; c < cols; ++c) {
+            row_output[c] = scale * (static_cast<float>(row_data[c]) - zero_point);
+          }
+        }
+      }
+    } else if (num_bits == 4) {
+      const int64_t packed_cols = (cols + 1) / 2;
+      for (int64_t r = 0; r < rows; ++r) {
+        const uint8_t* row_data = quantized_data + r * packed_cols;
+        float* row_output = dequantized_data + r * cols;
+
+        if (block_size > 0) {
+          for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
+            const int64_t block_end = std::min(block_start + block_size, cols);
+            const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
+            const int64_t scale_idx = r * blocks_per_row + block_idx;
+            const float scale = static_cast<float>(scales[scale_idx]);
+
+            for (int64_t c = block_start; c < block_end; c += 2) {
+              const uint8_t packed_val = row_data[c / 2];
+              const uint8_t val0 = packed_val & 0x0F;
+              const uint8_t val1 = packed_val >> 4;
+
+              row_output[c] = scale * (static_cast<float>(val0) - zero_point);
+              if (c + 1 < block_end) {
+                row_output[c + 1] = scale * (static_cast<float>(val1) - zero_point);
+              }
+            }
+          }
+        } else {
+          const float scale = static_cast<float>(scales[r]);
+          for (int64_t c = 0; c < cols; c += 2) {
+            const uint8_t packed_val = row_data[c / 2];
+            const uint8_t val0 = packed_val & 0x0F;
+            const uint8_t val1 = packed_val >> 4;
+
+            row_output[c] = scale * (static_cast<float>(val0) - zero_point);
+            if (c + 1 < cols) {
+              row_output[c + 1] = scale * (static_cast<float>(val1) - zero_point);
+            }
+          }
+        }
+      }
+    }
+  }
+}
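Aside: the nibble unpacking used throughout the dequantizer packs two 4-bit values per byte — the even column in the low nibble, the odd column in the high nibble, with 8 as the symmetric zero point. A tiny worked example (standalone, illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint8_t packed = 0xA3;  // low nibble = 3, high nibble = 10
      const float scale = 0.5f;
      const float even = scale * (static_cast<float>(packed & 0x0F) - 8.0f);  // 0.5 * (3 - 8)
      const float odd  = scale * (static_cast<float>(packed >> 4) - 8.0f);    // 0.5 * (10 - 8)
      assert(even == -2.5f && odd == 1.0f);
      return 0;
    }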
+
+template <typename TScale>
+void DequantizeBlock(const uint8_t* quantized_data,
+                     const TScale* scales,
+                     int64_t block_size,
+                     int64_t num_bits,
+                     int64_t rows,
+                     int64_t cols,
+                     float* dequantized_data,
+                     MLAS_THREADPOOL* thread_pool = nullptr) {
+  DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, dequantized_data, thread_pool);
+}
+
 template <typename T>
 QMoECPU<T>::QMoECPU(const OpKernelInfo& op_kernel_info)
     : OpKernel(op_kernel_info),
@@ -63,11 +350,15 @@ QMoECPU<T>::QMoECPU(const OpKernelInfo& op_kernel_info)
   ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8,
               "Attribute 'expert_weight_bits' must be 4 or 8.");
   block_size_ = op_kernel_info.GetAttrOrDefault<int64_t>("block_size", 0);
+
+  if (block_size_ > 0) {
+    ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided.");
+    ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2.");
+  }
 }
 
 template <typename T>
 Status QMoECPU<T>::Compute(OpKernelContext* context) const {
-  // --- 1. Get Inputs and Attributes ---
   const auto* input = context->Input<Tensor>(0);
   const auto* router_probs = context->Input<Tensor>(1);
   const auto* fc1_experts_weights = context->Input<Tensor>(2);
@@ -87,7 +378,8 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
                                             fc2_experts_weights, fc2_experts_bias, fc2_scales,
                                             fc3_experts_weights, fc3_experts_bias, fc3_scales,
                                             expert_weight_bits_ == 4 ? 2 : 1,
-                                            true));
+                                            true,
+                                            block_size_));
 
   if (fc3_experts_weights || fc3_experts_bias || fc3_scales) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE");
@@ -109,19 +401,17 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
   const size_t output_buffer_size = static_cast<size_t>(output->Shape().Size());
 
   const T* input_data = input->Data<T>();
-  const T* router_probs_data = router_probs->Data<T>();
 
-  // --- 2. Routing Logic: Assign tokens to experts ---
   IAllocatorUniquePtr<float> router_logits_float_buffer;
   const float* router_logits_float;
   if constexpr (std::is_same_v<T, MLFloat16>) {
     router_logits_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * num_experts));
     router_logits_float = router_logits_float_buffer.get();
-    MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(router_probs_data),
+    MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(router_probs->Data<T>()),
                                  const_cast<float*>(router_logits_float),
                                  static_cast<size_t>(num_tokens * num_experts));
   } else {
-    router_logits_float = reinterpret_cast<const float*>(router_probs_data);
+    router_logits_float = reinterpret_cast<const float*>(router_probs->Data<T>());
   }
 
   auto route_expert_ptr = IAllocator::MakeUniquePtr<int>(allocator, static_cast<size_t>(num_tokens * k_));
@@ -129,57 +419,57 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
   auto route_scale_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * k_));
   float* route_scale = route_scale_ptr.get();
 
-  // Parallelize the routing logic to improve performance for large token batches.
-  // Minor performance regression for single-token decoding is an acceptable trade-off
-  int num_routing_threads = (tp == nullptr || num_tokens < 4096) ? 1 : std::min(static_cast<int>(num_tokens), concurrency::ThreadPool::DegreeOfParallelism(tp));
+  const int max_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1;
+  const int64_t thread_divisor = std::max(1, max_threads * 4);
+  const int64_t min_work_per_thread = std::max(int64_t{32}, static_cast<int64_t>(num_tokens / thread_divisor));
+  const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(narrow<int>(num_tokens / std::max(int64_t{1}, min_work_per_thread)), max_threads);
+  const int num_routing_threads = std::max(1, optimal_routing_threads);
 
   std::vector<std::vector<std::vector<int64_t>>> thread_local_expert_token_maps(num_routing_threads);
   for (auto& map : thread_local_expert_token_maps) {
     map.resize(static_cast<size_t>(num_experts));
+    for (auto& expert_tokens : map) {
+      expert_tokens.reserve(32);
+    }
   }
 
   concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) {
-    auto work = concurrency::ThreadPool::PartitionWork(static_cast<int>(thread_id), num_routing_threads, static_cast<std::ptrdiff_t>(num_tokens));
+    auto work = concurrency::ThreadPool::PartitionWork(narrow<int>(thread_id), num_routing_threads, static_cast<std::ptrdiff_t>(num_tokens));
 
     auto& local_expert_token_map = thread_local_expert_token_maps[thread_id];
 
-    // Pre-allocate buffers for this thread to reuse, avoiding allocations inside the loop.
     std::vector<std::pair<float, int64_t>> sorted_logits(static_cast<size_t>(num_experts));
     std::vector<float> top_k_exp(static_cast<size_t>(k_));
 
     for (int64_t i = work.start; i < work.end; ++i) {
       const float* logits = router_logits_float + i * num_experts;
 
-      for (int64_t j = 0; j < num_experts; ++j) {
-        sorted_logits[static_cast<size_t>(j)] = {logits[j], j};
-      }
-      std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast<std::ptrdiff_t>(k_), sorted_logits.end(), std::greater<>());
 
-      float max_logit = -std::numeric_limits<float>::infinity();
-      for (int64_t j = 0; j < k_; ++j) {
-        if (sorted_logits[static_cast<size_t>(j)].first > max_logit) {
-          max_logit = sorted_logits[static_cast<size_t>(j)].first;
-        }
+      for (size_t j = 0; j < narrow<size_t>(num_experts); ++j) {
+        sorted_logits[j] = {logits[j], j};
       }
+      std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast<std::ptrdiff_t>(k_),
+                        sorted_logits.end(), std::greater<>());
+
+      float max_logit = sorted_logits[0].first;
       float sum_exp = 0.0f;
-      for (int64_t j = 0; j < k_; ++j) {
-        top_k_exp[static_cast<size_t>(j)] = std::exp(sorted_logits[static_cast<size_t>(j)].first - max_logit);
-        sum_exp += top_k_exp[static_cast<size_t>(j)];
+      for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+        top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
+        sum_exp += top_k_exp[j];
       }
 
-      float scale = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
-      for (int64_t j = 0; j < k_; ++j) {
-        int64_t expert_idx = sorted_logits[static_cast<size_t>(j)].second;
-        int64_t route_idx = i * k_ + j;
-        route_expert[route_idx] = static_cast<int>(expert_idx);
-        route_scale[route_idx] = top_k_exp[static_cast<size_t>(j)] * scale;
-        if (route_scale[route_idx] > 0.0f) {
+      const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
+      for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+        int64_t expert_idx = sorted_logits[j].second;
+        int64_t route_idx = i * k_ + narrow<int64_t>(j);
+        route_expert[route_idx] = narrow<int>(expert_idx);
+        route_scale[route_idx] = top_k_exp[j] * inv_sum;
+        if (route_scale[route_idx] > 1e-8f) {  // Use small threshold to avoid zero weights
          local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
         }
       }
     }
   });
 
-  // Merge the maps from each thread into a single global map.
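Aside: the routing loop above is a top-k softmax, normalized only over the k selected experts and shifted by the max logit for numerical stability. A self-contained sketch of the same math (illustrative only; `RouteTopK` is a hypothetical name and assumes 1 <= k <= logits.size()):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <functional>
    #include <utility>
    #include <vector>

    // Returns (expert_id, weight) pairs; weights sum to 1 over the top-k.
    std::vector<std::pair<int64_t, float>> RouteTopK(const std::vector<float>& logits, int64_t k) {
      std::vector<std::pair<float, int64_t>> sorted(logits.size());
      for (size_t j = 0; j < logits.size(); ++j) sorted[j] = {logits[j], static_cast<int64_t>(j)};
      std::partial_sort(sorted.begin(), sorted.begin() + k, sorted.end(), std::greater<>());

      const float max_logit = sorted[0].first;  // largest of the top-k
      float sum_exp = 0.0f;
      std::vector<std::pair<int64_t, float>> routes(static_cast<size_t>(k));
      for (size_t j = 0; j < static_cast<size_t>(k); ++j) {
        const float e = std::exp(sorted[j].first - max_logit);  // shifted for stability
        routes[j] = {sorted[j].second, e};
        sum_exp += e;
      }
      for (auto& r : routes) r.second /= sum_exp;
      return routes;
    }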
   std::vector<std::vector<int64_t>> expert_token_map(static_cast<size_t>(num_experts));
   for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
     size_t total_tokens_for_expert = 0;
@@ -187,18 +477,17 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
       total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast<size_t>(expert_idx)].size();
     }
     expert_token_map[static_cast<size_t>(expert_idx)].reserve(total_tokens_for_expert);
-  }
 
-  for (int t = 0; t < num_routing_threads; ++t) {
-    for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
+    for (int t = 0; t < num_routing_threads; ++t) {
       auto& local_tokens = thread_local_expert_token_maps[t][static_cast<size_t>(expert_idx)];
       if (!local_tokens.empty()) {
-        expert_token_map[static_cast<size_t>(expert_idx)].insert(expert_token_map[static_cast<size_t>(expert_idx)].end(), local_tokens.begin(), local_tokens.end());
+        expert_token_map[static_cast<size_t>(expert_idx)].insert(
+            expert_token_map[static_cast<size_t>(expert_idx)].end(),
+            local_tokens.begin(), local_tokens.end());
       }
     }
   }
 
-  // --- 3. Parallel Expert Computation ---
   IAllocatorUniquePtr<float> input_float_buffer;
   const float* input_float;
   if constexpr (std::is_same_v<T, MLFloat16>) {
@@ -211,118 +500,426 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
     input_float = reinterpret_cast<const float*>(input_data);
   }
 
-  int num_expert_threads = (tp == nullptr) ? 1 : std::min(static_cast<int>(num_experts), concurrency::ThreadPool::DegreeOfParallelism(tp));
+  const int max_expert_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1;
+  const int64_t total_expert_work = std::accumulate(expert_token_map.begin(), expert_token_map.end(), 0LL,
+                                                    [](int64_t sum, const std::vector<int64_t>& tokens) { return sum + static_cast<int64_t>(tokens.size()); });
+  const int64_t expert_thread_divisor = std::max(1, max_expert_threads * 8);
+  const int64_t min_expert_work_per_thread = std::max(int64_t{16}, total_expert_work / expert_thread_divisor);
+
+  int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ?
1 : std::min(narrow(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(narrow(num_experts), max_expert_threads)); if (num_expert_threads == 0) num_expert_threads = 1; + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); float* thread_local_outputs = thread_local_outputs_ptr.get(); - memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + std::memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); - // Pre-calculate workspace size per thread to avoid allocations inside the loop size_t max_tokens_per_expert = 0; for (const auto& tokens : expert_token_map) { - if (tokens.size() > max_tokens_per_expert) { - max_tokens_per_expert = tokens.size(); - } + max_tokens_per_expert = std::max(max_tokens_per_expert, tokens.size()); } - const size_t A1_size = static_cast(max_tokens_per_expert * hidden_size); - const size_t C1_size = static_cast(max_tokens_per_expert * fc1_out_features); - const size_t A2_size = static_cast(max_tokens_per_expert * inter_size); - const size_t C2_size = static_cast(max_tokens_per_expert * hidden_size); - const size_t B1_dequant_size = static_cast(fc1_out_features * hidden_size); - const size_t B2_dequant_size = static_cast(hidden_size * inter_size); - const size_t bias1_size = static_cast(fc1_out_features); - const size_t bias2_size = static_cast(hidden_size); + const auto align_size = [](size_t size) -> size_t { + return (size + 63) & ~63; + }; + + const size_t A1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); + const size_t C1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(fc1_out_features)); + const size_t A2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(inter_size)); + const size_t C2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); + const size_t B1_dequant_size = align_size(static_cast(fc1_out_features) * static_cast(hidden_size)); + const size_t B2_dequant_size = align_size(static_cast(hidden_size) * static_cast(inter_size)); + + const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + + B1_dequant_size + B2_dequant_size; - const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size + bias1_size + bias2_size; auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * workspace_elements_per_thread); float* workspace = workspace_ptr.get(); + auto bias_conversion_buffers_ptr = IAllocator::MakeUniquePtr(allocator, + static_cast(num_expert_threads) * (static_cast(fc1_out_features) + static_cast(hidden_size))); + float* bias_conversion_buffers = bias_conversion_buffers_ptr.get(); + + const auto& fc1_scales_dims = fc1_scales->Shape().GetDims(); + const auto& fc2_scales_dims = fc2_scales->Shape().GetDims(); + const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); + const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); + + const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); + const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); + const T* fc1_scales_data = fc1_scales->Data(); + const T* fc2_scales_data = fc2_scales->Data(); + const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; + const T* fc2_bias_data = fc2_experts_bias ? 
fc2_experts_bias->Data() : nullptr; + + const int64_t pack_unit = (8 / expert_weight_bits_); + const int64_t fc1_packed_cols = (hidden_size + pack_unit - 1) / pack_unit; + const int64_t fc2_packed_cols = (inter_size + pack_unit - 1) / pack_unit; + const bool has_fc1_bias = (fc1_bias_data != nullptr); + const bool has_fc2_bias = (fc2_bias_data != nullptr); + + std::vector> expert_workload; + size_t total_work = 0; + + for (int64_t i = 0; i < num_experts; ++i) { + const size_t token_count = expert_token_map[static_cast(i)].size(); + if (token_count > 0) { + expert_workload.emplace_back(i, token_count); + total_work += token_count; + } + } + + if (total_work < 48) { + num_expert_threads = 1; + } else if (total_work < 192) { + num_expert_threads = std::min(num_expert_threads, 2); + } else if (total_work < 512) { + num_expert_threads = std::min(num_expert_threads, 4); + } + + std::sort(expert_workload.begin(), expert_workload.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + + std::vector> expert_batches(num_expert_threads); + size_t thread_idx = 0; + for (const auto& work : expert_workload) { + expert_batches[thread_idx].push_back(work.first); + thread_idx = (thread_idx + 1) % static_cast(num_expert_threads); + } + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { - int thread_id = static_cast(thread_id_pd); - auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + const int thread_id = narrow(thread_id_pd); + const auto& expert_batch = expert_batches[static_cast(thread_id)]; float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; - for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + float* thread_bias1_buffer = bias_conversion_buffers + static_cast(thread_id) * (static_cast(fc1_out_features) + static_cast(hidden_size)); + float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); + + for (int64_t expert_idx : expert_batch) { const auto& routes = expert_token_map[static_cast(expert_idx)]; if (routes.empty()) { continue; } - const int64_t num_expert_tokens = routes.size(); + const int64_t num_expert_tokens = static_cast(routes.size()); - // Partition the workspace for the current expert float* A1 = thread_workspace; - float* C1 = A1 + num_expert_tokens * hidden_size; - float* A2 = C1 + num_expert_tokens * fc1_out_features; - float* C2 = A2 + num_expert_tokens * inter_size; - float* B1_dequant = C2 + num_expert_tokens * hidden_size; - float* B2_dequant = B1_dequant + fc1_out_features * hidden_size; - float* bias1_float = B2_dequant + hidden_size * inter_size; - float* bias2_float = bias1_float + fc1_out_features; - - // --- Gather input tokens for the current expert --- - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const int64_t token_idx = routes[static_cast(i)] / k_; - memcpy(A1 + i * hidden_size, - input_float + token_idx * hidden_size, - static_cast(hidden_size) * sizeof(float)); + float* C1 = A1 + A1_size; + float* A2 = C1 + C1_size; + float* C2 = A2 + A2_size; + float* B1_dequant = C2 + C2_size; + float* B2_dequant = B1_dequant + B1_dequant_size; + + const int64_t dynamic_block_size = GetOptimalBlockSize(num_expert_tokens, tp ? 
concurrency::ThreadPool::DegreeOfParallelism(tp) : 1); + const int64_t num_blocks = (num_expert_tokens + dynamic_block_size - 1) / dynamic_block_size; + + if (num_expert_tokens >= 8 && num_blocks > 1 && tp != nullptr) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_idx = block_idx * dynamic_block_size; + const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens); + + for (int64_t i = start_idx; i < end_idx; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + const float* src = input_float + token_idx * hidden_size; + float* dst = A1 + i * hidden_size; + + std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + const float* src = input_float + token_idx * hidden_size; + float* dst = A1 + i * hidden_size; + + if (ShouldUseMemcpy(hidden_size)) { + std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + } else { + const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); + size_t j = 0; + for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (size_t k = 0; k < unroll_factor; ++k) { + dst[j + k] = src[j + k]; + } + } + for (; j < narrow(hidden_size); ++j) { + dst[j] = src[j]; + } + } + } + } + + const T* fc1_scales_ptr; + + if (is_fc1_block_wise) { + const int64_t fc1_blocks_per_row = fc1_scales_dims[2]; + fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features * fc1_blocks_per_row; + } else { + fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features; } - // --- FC1 GEMM (X * W1^T) --- - DequantizeBlock(fc1_experts_weights->Data() + expert_idx * fc1_out_features * (hidden_size / (8 / expert_weight_bits_)), - fc1_scales->Data() + expert_idx * fc1_out_features * (block_size_ > 0 ? hidden_size / block_size_ : 1), - block_size_, expert_weight_bits_, - fc1_out_features, hidden_size, B1_dequant); + const int64_t dequant_block_size = GetDequantBlockSize(fc1_out_features, num_expert_tokens); + const int64_t num_dequant_blocks = (fc1_out_features + dequant_block_size - 1) / dequant_block_size; + + const size_t m = static_cast(num_expert_tokens); + const size_t n = static_cast(fc1_out_features); + const size_t k = static_cast(hidden_size); + + MLAS_BLK_QUANT_TYPE q_type; + bool use_direct_q4_gemm = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, q_type); + bool fc1_used_direct_q4 = false; + bool fc1_bias_handled_by_q4_gemm = false; + + if (use_direct_q4_gemm) { + IAllocatorUniquePtr mlas_packed_fc1; + Status convert_status = ConvertToMlasQ4Format( + fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + fc1_scales_ptr, + is_fc1_block_wise ? 
block_size_ : 0, + expert_weight_bits_, + fc1_out_features, + hidden_size, + q_type, + allocator, + mlas_packed_fc1); + + if (convert_status.IsOK()) { + float* fc1_bias_float = nullptr; + IAllocatorUniquePtr fc1_bias_buffer; + + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); + fc1_bias_float = fc1_bias_buffer.get(); + + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); + } else { + for (int64_t i = 0; i < fc1_out_features; ++i) { + fc1_bias_float[i] = static_cast(B1_bias[i]); + } + } + } + + Status gemm_status = DirectQ4Gemm(A1, mlas_packed_fc1.get(), fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); + + if (gemm_status.IsOK()) { + fc1_used_direct_q4 = true; + goto fc1_gemm_done; + } + } + // If direct Q4 GEMM failed, fall back to traditional approach + } + + // Traditional approach: dequantize + regular GEMM + if (num_dequant_blocks > 1 && fc1_out_features >= 32) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_row = block_idx * dequant_block_size; + const int64_t end_row = std::min(start_row + dequant_block_size, fc1_out_features); + const auto offset = expert_idx * fc1_out_features * fc1_packed_cols + start_row * fc1_packed_cols; + DequantizeBlock(fc1_weights_data + offset, + fc1_scales_ptr + (is_fc1_block_wise ? start_row * fc1_scales_dims[2] : start_row), + is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, + end_row - start_row, hidden_size, B1_dequant + start_row * hidden_size, tp); + }); + } else { + DequantizeBlock(fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + fc1_scales_ptr, + is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, + fc1_out_features, hidden_size, B1_dequant, tp); + } MlasGemm(CblasNoTrans, CblasTrans, - static_cast(num_expert_tokens), static_cast(fc1_out_features), static_cast(hidden_size), - 1.0f, A1, static_cast(hidden_size), - B1_dequant, static_cast(hidden_size), - 0.0f, C1, static_cast(fc1_out_features), - nullptr); - - const T* B1_bias = (fc1_experts_bias) ? 
fc1_experts_bias->Data() + expert_idx * fc1_out_features : nullptr; - if (B1_bias) { + m, n, k, + 1.0f, A1, k, + B1_dequant, k, + 0.0f, C1, n, + tp); + + fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; + if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), bias1_float, static_cast(fc1_out_features)); + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); } else { - memcpy(bias1_float, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + if (ShouldUseMemcpy(fc1_out_features)) { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } else { + const size_t unroll_factor = static_cast(GetUnrollFactor(fc1_out_features)); + size_t j = 0; + for (; j + unroll_factor <= static_cast(fc1_out_features); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + thread_bias1_buffer[j + loop_k] = static_cast(B1_bias[j + loop_k]); + } + } + for (; j < static_cast(fc1_out_features); ++j) { + thread_bias1_buffer[j] = static_cast(B1_bias[j]); + } + } } + for (int64_t i = 0; i < num_expert_tokens; ++i) { - for (int64_t j = 0; j < fc1_out_features; ++j) { - C1[i * fc1_out_features + j] += bias1_float[j]; + float* C1_row = C1 + i * fc1_out_features; + const size_t unroll_factor = static_cast(GetUnrollFactor(fc1_out_features)); + + size_t j = 0; + for (; j + unroll_factor <= static_cast(fc1_out_features); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + C1_row[j + loop_k] += thread_bias1_buffer[j + loop_k]; + } + } + for (; j < static_cast(fc1_out_features); ++j) { + C1_row[j] += thread_bias1_buffer[j]; } } } - // --- Activation --- - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const float* C1_token = C1 + i * fc1_out_features; - float* A2_token = A2 + i * inter_size; - ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + fc1_gemm_done: + + const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + if (num_expert_tokens >= activation_threshold && tp != nullptr) { + const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); + const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; + + if (num_activation_blocks > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_activation_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_token = block_idx * activation_block_size; + const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); + + for (int64_t i = start_token; i < end_token; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + } + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + 
i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + } + + const T* fc2_scales_ptr; + + if (is_fc2_block_wise) { + const int64_t fc2_blocks_per_row = fc2_scales_dims[2]; + fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size * fc2_blocks_per_row; + } else { + fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size; } - // --- FC2 GEMM (A2 * W2^T) --- - DequantizeBlock(fc2_experts_weights->Data() + expert_idx * hidden_size * (inter_size / (8 / expert_weight_bits_)), - fc2_scales->Data() + expert_idx * hidden_size * (block_size_ > 0 ? inter_size / block_size_ : 1), - block_size_, expert_weight_bits_, - hidden_size, inter_size, B2_dequant); + const int64_t fc2_dequant_block_size = GetDequantBlockSize(hidden_size, num_expert_tokens); + const int64_t num_fc2_dequant_blocks = (hidden_size + fc2_dequant_block_size - 1) / fc2_dequant_block_size; + + const size_t m2 = static_cast(num_expert_tokens); + const size_t n2 = static_cast(hidden_size); + const size_t k2 = static_cast(inter_size); + + MLAS_BLK_QUANT_TYPE q_type2; + bool use_direct_q4_gemm_fc2 = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2); + bool fc2_used_direct_q4 = false; + + if (use_direct_q4_gemm_fc2) { + IAllocatorUniquePtr mlas_packed_fc2; + Status convert_status = ConvertToMlasQ4Format( + fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + fc2_scales_ptr, + is_fc2_block_wise ? block_size_ : 0, + expert_weight_bits_, + hidden_size, + inter_size, + q_type2, + allocator, + mlas_packed_fc2); + + if (convert_status.IsOK()) { + float* fc2_bias_float = nullptr; + IAllocatorUniquePtr fc2_bias_buffer; + + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); + fc2_bias_float = fc2_bias_buffer.get(); + + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); + } else { + for (int64_t i = 0; i < hidden_size; ++i) { + fc2_bias_float[i] = static_cast(B2_bias[i]); + } + } + } + + Status gemm_status = DirectQ4Gemm(A2, mlas_packed_fc2.get(), fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, q_type2, tp); + + if (gemm_status.IsOK()) { + fc2_used_direct_q4 = true; + goto fc2_gemm_done; + } + } + + // If direct Q4 GEMM failed, fall back to traditional approach + } + + // Traditional approach: dequantize + regular GEMM + if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_row = block_idx * fc2_dequant_block_size; + const int64_t end_row = std::min(start_row + fc2_dequant_block_size, hidden_size); + const auto offset = expert_idx * hidden_size * fc2_packed_cols + start_row * fc2_packed_cols; + DequantizeBlock(fc2_weights_data + offset, + fc2_scales_ptr + (is_fc2_block_wise ? start_row * fc2_scales_dims[2] : start_row), + is_fc2_block_wise ? block_size_ : 0, expert_weight_bits_, + end_row - start_row, inter_size, B2_dequant + start_row * inter_size, tp); + }); + } else { + DequantizeBlock(fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + fc2_scales_ptr, + is_fc2_block_wise ? 
block_size_ : 0, expert_weight_bits_, + hidden_size, inter_size, B2_dequant, tp); + } MlasGemm(CblasNoTrans, CblasTrans, - static_cast(num_expert_tokens), static_cast(hidden_size), static_cast(inter_size), - 1.0f, A2, static_cast(inter_size), - B2_dequant, static_cast(inter_size), - 0.0f, C2, static_cast(hidden_size), - nullptr); - - const T* B2_bias = (fc2_experts_bias) ? fc2_experts_bias->Data() + expert_idx * hidden_size : nullptr; - if (B2_bias) { + m2, n2, k2, + 1.0f, A2, k2, + B2_dequant, k2, + 0.0f, C2, n2, + tp); + + fc2_gemm_done: + + bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; + if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), bias2_float, static_cast(hidden_size)); + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); } else { - memcpy(bias2_float, B2_bias, static_cast(hidden_size) * sizeof(float)); + if (ShouldUseMemcpy(hidden_size)) { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } else { + const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); + size_t j = 0; + for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + thread_bias2_buffer[j + loop_k] = static_cast(B2_bias[j + loop_k]); + } + } + for (; j < narrow(hidden_size); ++j) { + thread_bias2_buffer[j] = static_cast(B2_bias[j]); + } + } } } @@ -331,28 +928,89 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t token_idx = route_idx / k_; const float weight = route_scale[route_idx]; + if (token_idx < 0 || token_idx >= num_tokens) continue; + const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); - if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { - // Skip this token to prevent buffer overflow - continue; - } + if (buffer_offset + static_cast(hidden_size) > output_buffer_size) continue; float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - for (int64_t j = 0; j < hidden_size; ++j) { - dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); + + if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); + size_t j = 0; + for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + dest[j + loop_k] += weight * (src[j + loop_k] + thread_bias2_buffer[j + loop_k]); + } + } + for (; j < narrow(hidden_size); ++j) { + dest[j] += weight * (src[j] + thread_bias2_buffer[j]); + } + } else { + const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); + size_t j = 0; + for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + dest[j + loop_k] += weight * src[j + loop_k]; + } + } + for (; j < narrow(hidden_size); ++j) { + dest[j] += weight * src[j]; + } } } } }); - // --- 4. 
Final Reduction (accumulate expert outputs to a float buffer) ---
   auto accumulate = [&](float* buffer) {
-    memset(buffer, 0, output_buffer_size * sizeof(float));
-    for (int i = 0; i < num_expert_threads; ++i) {
-      const size_t thread_offset = static_cast<size_t>(i) * output_buffer_size;
-      for (size_t j = 0; j < output_buffer_size; ++j) {
-        buffer[j] += thread_local_outputs[thread_offset + j];
+    std::memset(buffer, 0, output_buffer_size * sizeof(float));
+
+    const int max_acc_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1;
+    const size_t acc_thread_divisor = std::max(size_t{1}, static_cast<size_t>(max_acc_threads) * 8);
+    const size_t min_elements_per_thread = std::max(size_t{32}, output_buffer_size / acc_thread_divisor);
+    const int optimal_acc_threads = (tp == nullptr || output_buffer_size < min_elements_per_thread) ? 1 : std::min(narrow<int>(output_buffer_size / std::max(size_t{1}, min_elements_per_thread)), max_acc_threads);
+    const int num_acc_threads = std::max(1, optimal_acc_threads);
+
+    if (num_acc_threads > 1) {
+      concurrency::ThreadPool::TrySimpleParallelFor(tp, num_acc_threads, [&](std::ptrdiff_t acc_thread_id) {
+        const size_t elements_per_thread = output_buffer_size / static_cast<size_t>(num_acc_threads);
+        const size_t start_idx = static_cast<size_t>(acc_thread_id) * elements_per_thread;
+        const size_t end_idx = (acc_thread_id == num_acc_threads - 1) ? output_buffer_size : start_idx + elements_per_thread;
+
+        for (int i = 0; i < num_expert_threads; ++i) {
+          const size_t thread_offset = static_cast<size_t>(i) * output_buffer_size;
+          const float* src = thread_local_outputs + thread_offset + start_idx;
+          float* dst = buffer + start_idx;
+
+          size_t j = 0;
+          const size_t chunk_size = end_idx - start_idx;
+          const size_t unroll_factor = static_cast<size_t>(GetUnrollFactor(static_cast<int64_t>(chunk_size)));
+          for (; j + unroll_factor <= chunk_size; j += unroll_factor) {
+            for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) {
+              dst[j + loop_k] += src[j + loop_k];
+            }
+          }
+          for (; j < chunk_size; ++j) {
+            dst[j] += src[j];
+          }
+        }
+      });
+    } else {
+      for (int i = 0; i < num_expert_threads; ++i) {
+        const size_t thread_offset = static_cast<size_t>(i) * output_buffer_size;
+        const float* src = thread_local_outputs + thread_offset;
+
+        size_t j = 0;
+        const size_t unroll_factor = narrow<size_t>(GetUnrollFactor(narrow<int64_t>(output_buffer_size)));
+        for (; j + unroll_factor <= output_buffer_size; j += unroll_factor) {
+          for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) {
+            buffer[j + loop_k] += src[j + loop_k];
+          }
+        }
+        for (; j < output_buffer_size; ++j) {
+          buffer[j] += src[j];
+        }
       }
     }
   };
@@ -362,18 +1020,16 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
   float* final_output_float = final_output_float_ptr.get();
   accumulate(final_output_float);
 
-  // --- 5. Convert final float buffer to output type T ---
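Aside: the accumulate lambda above sums per-thread partial outputs into one buffer, giving each worker a disjoint [start, end) slice so there are no overlapping writes and no locking. A minimal sketch of the same reduction pattern (illustrative only; `SumChunk` is a hypothetical name, invoked once per worker with its slice bounds):

    #include <cstddef>
    #include <vector>

    void SumChunk(const std::vector<std::vector<float>>& partials,
                  float* out, size_t start, size_t end) {
      for (const auto& partial : partials) {
        for (size_t j = start; j < end; ++j) {
          out[j] += partial[j];  // disjoint slices: no synchronization needed
        }
      }
    }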
   MlasConvertFloatToHalfBuffer(final_output_float,
                                reinterpret_cast<MLFloat16*>(output->MutableData<T>()),
                                static_cast<size_t>(output_buffer_size));
-  } else {  // T is float
+  } else {
     accumulate(output->MutableData<float>());
   }
 
   return Status::OK();
 }
 
-// Explicit template instantiation
 template QMoECPU<float>::QMoECPU(const OpKernelInfo& op_kernel_info);
 template Status QMoECPU<float>::Compute(OpKernelContext* context) const;
 
 template QMoECPU<MLFloat16>::QMoECPU(const OpKernelInfo& op_kernel_info);
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
index 2c59210bfabd4..5a3c5d1dd0364 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
@@ -37,10 +37,18 @@ void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t
       gate_val = std::min(gate_val, clamp_limit);
       linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit);
 
+      // Use numerically stable sigmoid computation (matches CUDA kernel behavior)
       float sigmoid_arg = activation_alpha * gate_val;
-      float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
-      float swish_out = gate_val * sigmoid_out;
+      float sigmoid_out;
+      if (sigmoid_arg > 0) {
+        float exp_neg = std::exp(-sigmoid_arg);
+        sigmoid_out = 1.0f / (1.0f + exp_neg);
+      } else {
+        float exp_pos = std::exp(sigmoid_arg);
+        sigmoid_out = exp_pos / (1.0f + exp_pos);
+      }
 
+      float swish_out = gate_val * sigmoid_out;
       output_data[i] = swish_out * (linear_val + activation_beta);
     }
   } else {
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
index 93d802ca05b42..167b2af946183 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -77,7 +77,8 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
                                             fc2_experts_weights, fc2_experts_bias_optional, nullptr,
                                             fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr,
                                             1,  // no quantization so pack size is 1
-                                            activation_type_ == ort_fastertransformer::ActivationType::SwiGLU));
+                                            activation_type_ == ort_fastertransformer::ActivationType::SwiGLU,
+                                            0));  // no block-wise quantization for sharded MoE
 
   ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
                     "num_experts should be divisible by world_size");
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index a5b9d483d5ad1..e5a064d59e360 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -45,7 +45,8 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
                                             fc2_experts_weights, fc2_experts_bias_optional, nullptr,
                                             fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr,
                                             1,  // no quantization so pack size is 1
-                                            activation_type_ == ort_fastertransformer::ActivationType::SwiGLU));
+                                            activation_type_ == ort_fastertransformer::ActivationType::SwiGLU,
+                                            0));  // no block-wise quantization for regular MoE
 
   using CudaT = typename OrtToCudaType<T>::type;
   auto stream = context->GetComputeStream();
diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
index dcf32bb3c5ae4..931b8ac09aa49 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
@@ -150,7 +150,8 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
                                             fc2_experts_weights,
fc2_experts_bias_optional, fc2_scales, fc3_experts_weights_optional, fc3_experts_bias_optional, fc3_scales_optional, expert_weight_bits_ == 4 ? 2 : 1, - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // CUDA doesn't support block-wise quantization yet #if defined(__GNUC__) #pragma GCC diagnostic push diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 98cc2158eb0d0..01ba492eb166e 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -226,13 +226,22 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne for (auto& node : graph_.Nodes()) { const KernelCreateInfo* kci = nullptr; auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); - if (!status.IsOK() && saving_ort_format) { - // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. - // in that case we assigned the node to that EP but do not compile it into a fused node. - // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it. - // we now revert to the CPU EP kernel as a fallback. - // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. - // if that's not possible for some reason we can fallback to the CPU EP implementation. + + // There are two cases where we allow fallback to CPU EP kernels: + // + // 1. if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. + // in that case we assigned the node to that EP but do not compile it into a fused node. + // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it. + // we now revert to the CPU EP kernel as a fallback. + // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. + // if that's not possible for some reason we can fallback to the CPU EP implementation. + // + // 2. If the node is a memcpy node. + // EPs may provide their own memcpy kernels. The CPU EP provides a generic version to fall back to if the EP does + // not provide one. 
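Aside: case 2 above hinges on recognizing memcpy nodes. A self-contained sketch of that predicate (illustrative only; `MockNode` stands in for onnxruntime::Node, and it assumes kOnnxDomain is the empty string, as in ORT's graph constants):

    #include <cassert>
    #include <string>

    struct MockNode {
      std::string domain;
      std::string op_type;
      const std::string& Domain() const { return domain; }
      const std::string& OpType() const { return op_type; }
    };

    // Memcpy nodes are the host<->device copy ops ORT inserts in the ONNX domain.
    bool IsMemcpyNode(const MockNode& node) {
      return node.Domain().empty() &&  // kOnnxDomain
             (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost");
    }

    int main() {
      assert(IsMemcpyNode({"", "MemcpyToHost"}));
      assert(!IsMemcpyNode({"com.microsoft", "MemcpyToHost"}));  // wrong domain
      assert(!IsMemcpyNode({"", "MatMul"}));
      return 0;
    }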
+ const bool allow_cpu_ep_kernel_fallback = saving_ort_format || utils::IsMemcpyNode(node); + + if (!status.IsOK() && allow_cpu_ep_kernel_fallback) { node.SetExecutionProviderType(kCpuExecutionProvider); status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 2c0a51f0bfdbc..ca64c7c7cae89 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -46,22 +46,13 @@ void DestroyStrings(void* p_data, int64_t elements) { ptr[i].~string(); } -bool ProviderIsCpuBased(const std::string& provider_type) { - return provider_type == onnxruntime::kCpuExecutionProvider || - provider_type == onnxruntime::kDnnlExecutionProvider || - provider_type == onnxruntime::kVitisAIExecutionProvider || - provider_type == onnxruntime::kOpenVINOExecutionProvider || - provider_type == onnxruntime::kNnapiExecutionProvider || - provider_type == onnxruntime::kVSINPUExecutionProvider || - provider_type == onnxruntime::kAclExecutionProvider || - provider_type == onnxruntime::kArmNNExecutionProvider || - provider_type == onnxruntime::kRknpuExecutionProvider || - provider_type == onnxruntime::kCoreMLExecutionProvider || - provider_type == onnxruntime::kSnpeExecutionProvider || - provider_type == onnxruntime::kQnnExecutionProvider || - provider_type == onnxruntime::kXnnpackExecutionProvider || - provider_type == onnxruntime::kAzureExecutionProvider || - provider_type == onnxruntime::utils::kInternalTestingExecutionProvider; +bool ProviderIsCpuBased(const IExecutionProvider& provider) { + return provider.GetDevice().Type() == OrtDevice::CPU; +} + +bool IsMemcpyNode(const Node& node) { + return node.Domain() == kOnnxDomain && + (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost"); } static common::Status AllocateHelper(const AllocatorPtr& allocator, @@ -210,7 +201,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_providers) { for (const auto& execution_provider : execution_providers) { - if (!ProviderIsCpuBased(execution_provider->Type())) { + if (!ProviderIsCpuBased(*execution_provider)) { return false; } } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 3a23093a5445b..4b4c483ba1202 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -52,12 +52,10 @@ void DestroyStrings(void* p_data, int64_t elements); const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info); -// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want -// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal. 
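Aside: the rewritten ProviderIsCpuBased above classifies an EP by the device it allocates on instead of matching a hard-coded list of provider names, so new CPU-based EPs are covered automatically. A self-contained sketch of the idea (illustrative only; `MockDevice`/`MockProvider` are stand-ins for OrtDevice and IExecutionProvider):

    #include <cassert>

    struct MockDevice { enum Type { CPU, GPU, NPU }; Type type; };
    struct MockProvider {
      MockDevice device;
      const MockDevice& GetDevice() const { return device; }
    };

    bool ProviderIsCpuBased(const MockProvider& provider) {
      return provider.GetDevice().type == MockDevice::CPU;  // no string comparisons
    }

    int main() {
      MockProvider cpu_based{{MockDevice::CPU}};  // e.g. an XNNPACK-style EP
      MockProvider gpu_based{{MockDevice::GPU}};  // e.g. a CUDA-style EP
      assert(ProviderIsCpuBased(cpu_based) && !ProviderIsCpuBased(gpu_based));
      return 0;
    }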
-constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider"; - // return true if the execution provider is CPU based (meaning no copies to device are required) -bool ProviderIsCpuBased(const std::string& provider_type); +bool ProviderIsCpuBased(const IExecutionProvider& provider); + +bool IsMemcpyNode(const Node& node); common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name, const OrtValue& orig_mlvalue, OrtValue& new_mlvalue); diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index cc7682b2b418d..9d49c16391f78 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -1,7 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "transformer_memcpy.h" +#include "core/optimizer/transformer_memcpy.h" + #include "core/common/logging/logging.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" @@ -12,18 +13,39 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { +static ProviderTypeToProviderMap GetProvidersByType( + const InlinedVector>& providers) { + ProviderTypeToProviderMap providers_by_type{}; + for (const auto provider : providers) { + providers_by_type.emplace(provider->Type(), provider); + } + return providers_by_type; +} + +MemcpyTransformer::MemcpyTransformer(InlinedVector> providers, + const KernelRegistryManager& registry_manager) + : GraphTransformer("MemcpyTransformer"), + providers_(std::move(providers)), + providers_by_type_(GetProvidersByType(providers_)), + registry_manager_(std::cref(registry_manager)) { +} + // implements MemCpy node insertion in graph transform // note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer class TransformerMemcpyImpl { public: - TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) - : graph_(graph), provider_(provider) {} + TransformerMemcpyImpl(onnxruntime::Graph& graph, const IExecutionProvider& provider, + const ProviderTypeToProviderMap& providers_by_type) + : graph_(graph), provider_(provider), providers_by_type_(providers_by_type) { + } bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); private: + bool IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const; + void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed, @@ -31,7 +53,9 @@ class TransformerMemcpyImpl { void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries, const logging::Logger& logger); - void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); + void AddCopyNode(onnxruntime::NodeArg* arg, + bool is_input, + const logging::Logger& logger); bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed, const logging::Logger& logger); @@ -55,7 +79,8 @@ class TransformerMemcpyImpl { std::map> provider_output_nodes_; onnxruntime::Graph& graph_; - std::string provider_; + const IExecutionProvider& provider_; + const ProviderTypeToProviderMap& providers_by_type_; }; /** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer. 
@@ -73,17 +98,18 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st // very simple GraphTransformer that uses TransformerMemcpyImpl for each graph // and mainly provides the subgraph recursion functionality -common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { - for (auto& provider : provider_types_) { - if (!utils::ProviderIsCpuBased(provider)) { - TransformerMemcpyImpl copy_impl(graph, provider); +Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + for (const auto provider : providers_) { + const auto& provider_type = provider->Type(); + if (!utils::ProviderIsCpuBased(*provider)) { + TransformerMemcpyImpl copy_impl(graph, *provider, providers_by_type_); int copy_node_counter = 0; auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter); - if (copy_node_counter > 0 && provider == kCudaExecutionProvider) { + if (copy_node_counter > 0 && provider_type == kCudaExecutionProvider) { LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name() - << " for " << provider + << " for " << provider_type << ". It might have negative impact on performance (including unable to run CUDA graph). " << "Set session_options.log_severity_level=1 to see the detail logs before this message."; } @@ -213,15 +239,42 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi return modified; } +static const IExecutionProvider* FindProviderByType(ProviderTypeToProviderMap providers_by_type, + std::string_view provider_type) { + const auto it = providers_by_type.find(provider_type); + if (it != providers_by_type.end()) { + return &*it->second; + } + return nullptr; +} + +bool TransformerMemcpyImpl::IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const { + const auto& node_provider_type = node.GetExecutionProviderType(); + const auto* node_provider = FindProviderByType(providers_by_type_, node_provider_type); + ORT_ENFORCE(node_provider != nullptr, "Unable to get provider associated with provider type ", node_provider_type); + + // Same provider? + if (node_provider->Type() == provider_.Type()) { + return true; + } + + const auto& node_provider_device = node_provider->GetDevice(); + const auto& provider_device = provider_.GetDevice(); + + // Same provider device type and vendor? 
+ // Same provider device type and vendor? + if (node_provider_device.Type() == provider_device.Type() && + node_provider_device.Vendor() == provider_device.Vendor()) { + return true; + } + + return false; +} + void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed, const logging::Logger& logger) { - auto node_provider_type = node.GetExecutionProviderType(); - if ((node_provider_type == provider_) || - (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || - (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) || - (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) { + if (IsNodeCompatibleWithProvider(node)) { provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; @@ -268,9 +321,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, else provider_output_defs_.insert(arg); } - } else if (node_provider_type != kCudaExecutionProvider && node_provider_type != kTensorrtExecutionProvider && - node_provider_type != kCudaExecutionProvider && node_provider_type != kNvTensorRTRTXExecutionProvider && - node_provider_type != kRocmExecutionProvider && node_provider_type != kMIGraphXExecutionProvider) { + } else { for (const auto* arg : node.InputDefs()) { if (arg->Exists()) non_provider_input_defs_.insert(arg); @@ -297,7 +348,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries, const logging::Logger& logger) { for (auto& it : graph_.Nodes()) { - if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue; + if (utils::IsMemcpyNode(it)) continue; auto input_it = std::find(it.MutableInputDefs().begin(), it.MutableInputDefs().end(), const_cast<onnxruntime::NodeArg*>(arg)); auto output_it = @@ -309,10 +360,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, auto node_provider_type = it.GetExecutionProviderType(); - if ((node_provider_type == provider_) || - (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || - (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) || - (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) { + if (IsNodeCompatibleWithProvider(it)) { const KernelCreateInfo* kci = nullptr; ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci)); if (arg_input_index != -1) { @@ -325,9 +373,11 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, } } -void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) { +void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, + bool is_input, + const logging::Logger& logger) { // create unique name for new def - std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_); + std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_.Type()); auto* new_arg = &graph_.GetOrCreateNodeArg(new_def_name, arg->TypeAsProto()); auto* src_arg = is_input ? arg : new_arg; @@ -338,12 +388,14 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input const auto op_name = is_input ? 
"MemcpyFromHost" : "MemcpyToHost"; LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name() - << " for " << provider_; + << " for " << provider_.Type(); auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory", std::vector{src_arg}, std::vector{dst_arg}); - new_node.SetExecutionProviderType(provider_); + + new_node.SetExecutionProviderType(provider_.Type()); + std::map map = {{arg, new_arg}}; auto it = provider_input_nodes_.find(arg); if (it != provider_input_nodes_.end()) { diff --git a/onnxruntime/core/optimizer/transformer_memcpy.h b/onnxruntime/core/optimizer/transformer_memcpy.h index a2403d269f89b..f6b60a83fcf32 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.h +++ b/onnxruntime/core/optimizer/transformer_memcpy.h @@ -5,13 +5,19 @@ #include +#include "gsl/gsl" + #include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/framework/execution_provider.h" #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry_manager.h" #include "core/optimizer/graph_transformer.h" namespace onnxruntime { +using ProviderTypeToProviderMap = InlinedHashMap>; + /** @Class MemcpyTransformer @@ -19,13 +25,14 @@ Transformer that inserts nodes to copy memory between devices when needed. */ class MemcpyTransformer : public GraphTransformer { public: - MemcpyTransformer(const std::vector& provider_types, const KernelRegistryManager& registry_manager) - : GraphTransformer("MemcpyTransformer"), provider_types_(provider_types), registry_manager_(std::cref(registry_manager)) {} + MemcpyTransformer(InlinedVector> providers, + const KernelRegistryManager& registry_manager); private: - common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - const std::vector provider_types_; + const InlinedVector> providers_; + const ProviderTypeToProviderMap providers_by_type_; std::reference_wrapper registry_manager_; }; diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 9cf89a04f031c..6cbbdd4e0a7ef 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -40,13 +40,16 @@ void Telemetry::SetLanguageProjection(uint32_t projection) const { void Telemetry::LogProcessInfo() const { } -void Telemetry::LogSessionCreationStart() const { +void Telemetry::LogSessionCreationStart(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } -void Telemetry::LogEvaluationStop() const { +void Telemetry::LogEvaluationStop(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } -void Telemetry::LogEvaluationStart() const { +void Telemetry::LogEvaluationStart(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index cb7a6176e5aec..b60345e1b8a80 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -48,11 +48,11 @@ class Telemetry { virtual void LogProcessInfo() const; - virtual void LogSessionCreationStart() const; + virtual void LogSessionCreationStart(uint32_t session_id) const; - virtual void LogEvaluationStop() const; + virtual void LogEvaluationStop(uint32_t session_id) 
const; - virtual void LogEvaluationStart() const; + virtual void LogEvaluationStart(uint32_t session_id) const; virtual void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 44ef44a3f5aff..2e5d334856278 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -194,7 +194,7 @@ void WindowsTelemetry::LogProcessInfo() const { process_info_logged = true; } -void WindowsTelemetry::LogSessionCreationStart() const { +void WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const { if (global_register_count_ == 0 || enabled_ == false) return; @@ -203,23 +203,26 @@ void WindowsTelemetry::LogSessionCreationStart() const { TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingLevel(WINEVENT_LEVEL_INFO)); } -void WindowsTelemetry::LogEvaluationStop() const { +void WindowsTelemetry::LogEvaluationStop(uint32_t session_id) const { if (global_register_count_ == 0 || enabled_ == false) return; TraceLoggingWrite(telemetry_provider_handle, - "EvaluationStop"); + "EvaluationStop", + TraceLoggingUInt32(session_id, "sessionId")); } -void WindowsTelemetry::LogEvaluationStart() const { +void WindowsTelemetry::LogEvaluationStart(uint32_t session_id) const { if (global_register_count_ == 0 || enabled_ == false) return; TraceLoggingWrite(telemetry_provider_handle, - "EvaluationStart"); + "EvaluationStart", + TraceLoggingUInt32(session_id, "sessionId")); } void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 7281063d50c2e..261d14a7fed8c 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -41,11 +41,11 @@ class WindowsTelemetry : public Telemetry { void LogProcessInfo() const override; - void LogSessionCreationStart() const override; + void LogSessionCreationStart(uint32_t session_id) const override; - void LogEvaluationStop() const override; + void LogEvaluationStop(uint32_t session_id) const override; - void LogEvaluationStart() const override; + void LogEvaluationStart(uint32_t session_id) const override; void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, 
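Tagging the start/stop events with a session id lets a trace consumer pair EvaluationStart and EvaluationStop per session when several sessions run concurrently in one process. A minimal sketch of such pairing (hypothetical consumer, not part of this change; the event and field names match those emitted above):

#include <cstdint>
#include <cstring>
#include <unordered_map>

// Pairs EvaluationStart/EvaluationStop events by the sessionId field added above.
struct EvalSpanTracker {
  std::unordered_map<uint32_t, uint64_t> start_ns;
  void OnEvent(const char* name, uint32_t session_id, uint64_t timestamp_ns) {
    if (std::strcmp(name, "EvaluationStart") == 0) {
      start_ns[session_id] = timestamp_ns;
    } else if (std::strcmp(name, "EvaluationStop") == 0) {
      auto it = start_ns.find(session_id);
      if (it != start_ns.end()) {
        uint64_t duration_ns = timestamp_ns - it->second;  // one evaluation span
        (void)duration_ns;
        start_ns.erase(it);
      }
    }
  }
};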
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 5eac0523d953a..1030e368a5fd6 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -4,6 +4,7 @@ #include "core/providers/cpu/cpu_execution_provider.h" #include "core/framework/allocator_utils.h" +#include "core/framework/memcpy.h" #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/int4.h" @@ -27,6 +28,33 @@ struct KernelRegistryAndStatus { } // namespace namespace onnxruntime { + +// The MemcpyFromHost and MemcpyToHost kernels registered for the CPU EP are generic memcpy kernels. +// Other EPs may provide their own memcpy kernels. +// For a memcpy between host (CPU) and device of some other EP: +// - If the EP provides the corresponding memcpy kernel, it will be used. +// - Otherwise, one of these generic memcpy kernels will be used. + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()), + Memcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPUOutput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()), + Memcpy); + CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kCpuExecutionProvider}, info_{info} {} @@ -39,6 +67,8 @@ std::vector<AllocatorPtr> CPUExecutionProvider::CreatePreferredAllocators() { } // Forward declarations of op kernels +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyToHost); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10, Clip); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, Elu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, HardSigmoid); @@ -1379,6 +1409,8 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>, + BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>, BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10, Clip)>, BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, Elu)>, BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, HardSigmoid)>, diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index 4fc3cdf961d17..9d3736173dae2 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -47,14 +47,14 @@ void make_copy(MLFloat16* mask_data, const MLFloat16* mask template <> void make_copy(float* mask_data, const bool* mask_index, size_t size) { for (size_t i = 0; i < size; ++i) { - mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits<float>::lowest(); + mask_data[i] = mask_index[i] ? 0.0f : negative_infinity<float>(); } } template <> void make_copy(MLFloat16* mask_data, const bool* mask_index, size_t size) { for (size_t i = 0; i < size; ++i) { - mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits<MLFloat16>::lowest(); + mask_data[i] = mask_index[i] ? 
MLFloat16(0.f) : negative_infinity<MLFloat16>(); } } @@ -236,7 +236,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs, mask_data = static_cast<T*>(allocated_ptr); for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { - mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits<T>::lowest(); + mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity<T>(); } } delete_mask_data = true; @@ -262,7 +262,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs, for (int i = 0; i < n_iter; ++i) { for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { - mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits<T>::lowest(); + mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity<T>(); } } } @@ -317,7 +317,8 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs, } // handling GQA - std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads; + std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads; + std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki; const T* k = K + k_input_chunk_length * ki; if (nullptr != present_key) { @@ -347,7 +348,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs, alpha, Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, parameters.head_size * parameters.q_num_heads, // lda - transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k, + transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k, transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb beta, output, @@ -555,7 +556,8 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu // handling GQA std::ptrdiff_t batch_i = i / num_heads; std::ptrdiff_t head_i = i % num_heads; - std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads; + std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads; + std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi; const T* v = V + v_input_chunk_length * vi; if (nullptr != present_value) { @@ -579,15 +581,15 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu // V is transposed but not QK. We use GemmEx with a different value for ldb. math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, - sequence_length, // M - v_head_size, // N - total_sequence_length, // K - 1.f, // alpha - attention_probs + attention_probs_offset, // QK - total_sequence_length, // lda - transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V - transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb - 0.f, // beta + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + 1.f, // alpha + attention_probs + attention_probs_offset, // QK + total_sequence_length, // lda + transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V + transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb + 0.f, // beta output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), v_head_size * num_heads, // ldc nullptr); 
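The head index remapping above changes how query heads are grouped onto KV heads. A quick standalone check with q_num_heads = 8 and kv_num_heads = 2 (values chosen only for illustration):

#include <cstdio>

int main() {
  const int q_num_heads = 8, kv_num_heads = 2;
  for (int head_i = 0; head_i < q_num_heads; ++head_i) {
    int old_ki = head_i % kv_num_heads;                // 0,1,0,1,0,1,0,1 (interleaved)
    int new_ki = head_i * kv_num_heads / q_num_heads;  // 0,0,0,0,1,1,1,1 (contiguous groups)
    std::printf("q head %d -> kv head old=%d new=%d\n", head_i, old_ki, new_ki);
  }
}

The contiguous grouping matches the usual GQA convention, where each block of q_num_heads / kv_num_heads consecutive query heads shares one KV head.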
diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h index 78889e48afb29..4fad6914f933d 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.h +++ b/onnxruntime/core/providers/cpu/llm/attention.h @@ -9,6 +9,16 @@ namespace onnxruntime { +template <typename T> +inline T negative_infinity() { + return -std::numeric_limits<T>::infinity(); +} + +template <> +inline MLFloat16 negative_infinity<MLFloat16>() { + return MLFloat16(-std::numeric_limits<float>::infinity()); +} + template <typename T> class AttentionBase : public OpKernel { public: 
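Two properties make -infinity a clean mask fill here (illustration only; the numbers below are not from this change): it is representable exactly in both float and MLFloat16, and exp(-inf) is exactly 0, so fully masked positions contribute nothing after softmax even if finite scores are added to the mask first. lowest() is a finite sentinel whose behavior depends on the type's range:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  std::printf("exp(-inf) = %g\n", std::exp(neg_inf));  // exactly 0
  // Adding any finite score to -inf still yields -inf, so the mask survives
  // the score + mask addition that precedes softmax.
  std::printf("-inf + 1e9f = %g\n", neg_inf + 1e9f);
  // lowest() is finite; score + lowest() can remain an ordinary finite value.
  std::printf("lowest + 1e38f = %g\n", std::numeric_limits<float>::lowest() + 1e38f);
}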
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 0ee18cc6799fc..62210d65848d1 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1466,12 +1466,15 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra fused_inputs.erase(it); erased.insert(output); } - // Only when output is neither in input list nor erased list, add the output to output list + // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list else if (erased.find(output) == erased.end()) { if (graph_output_names.find(output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } - fused_outputs[output] = output_order++; + + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + fused_outputs[output] = output_order++; + } } } } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3acb3347acee1..4a6545a0e6f0a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1576,6 +1576,17 @@ Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span ke LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value; return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type."); } + } else if (key == kOrtEpDynamicOptionsQnnHtpPerformanceMode) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + ParseHtpPerformanceMode(value, htp_performance_mode); + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } } else { LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported."; return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option"); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index b60f64db1734d..508d932459bf9 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2114,12 +2114,15 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph fused_inputs.erase(it); erased.insert(output); } - // Only when output is neither in input list nor erased list, add the output to output list + // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list else if (erased.find(output) == erased.end()) { if (graph_output_names.find(output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } - fused_outputs[output] = output_order++; + + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + fused_outputs[output] = output_order++; + } } } } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c0900c5ad28a0..e3291cdce62c5 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1517,12 +1517,12 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // Insert copy node/s. { - std::vector<std::string> provider_types; + InlinedVector<gsl::not_null<const IExecutionProvider*>> providers; for (auto& provider_ptr : execution_providers_) { - provider_types.push_back(provider_ptr->Type()); + providers.push_back(provider_ptr.get()); } - MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager_}; + MemcpyTransformer copy_transformer{std::move(providers), kernel_registry_manager_}; ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(copy_transformer, *session_logger_, graph)); } @@ -2041,7 +2041,7 @@ common::Status InferenceSession::Initialize() { ORT_TRY { LOGS(*session_logger_, INFO) << "Initializing session."; const Env& env = Env::Default(); - env.GetTelemetryProvider().LogSessionCreationStart(); + env.GetTelemetryProvider().LogSessionCreationStart(session_id_); bool have_cpu_ep = false; @@ -2980,7 +2980,7 @@ Status InferenceSession::Run(const RunOptions& run_options, } // log evaluation start to trace logging provider - env.GetTelemetryProvider().LogEvaluationStart(); + env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds)); ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches)); @@ -3133,7 +3133,7 @@ Status InferenceSession::Run(const RunOptions& run_options, } // log evaluation stop to trace logging provider - env.GetTelemetryProvider().LogEvaluationStop(); + env.GetTelemetryProvider().LogEvaluationStop(session_id_); // send out profiling events (optional) if (session_profiler_.IsEnabled()) { @@ -3383,17 +3383,58 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType for (const auto* def : def_list) { InlinedVector<SessionState::NodeInfo> node_info_vec; + Status status; if (type == SessionInputOutputType::kOutput) { - ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + status = session_state_->GetOutputNodeInfo(def->Name(), node_info_vec); } else { - ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); + status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); } - // all entries are for the same OrtDevice so use the first one. - // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice - // from the session state and use its OrtMemoryInfo. 
- auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); - memory_info.push_back(&allocator->Info()); + if (!status.IsOK()) { + if (type == SessionInputOutputType::kInput) { + return status; + } + + // Check first if this output is produced by an input that directly + // propagates to output with the same name. + status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); + if (status.IsOK()) { + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. + auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } else { + // Check if this output is produced by a constant initializer + // Pick the MemoryInfo from the initializer's OrtValue + const auto& ort_value_map = session_state_->GetOrtValueNameIdxMap(); + + OrtValueIndex ort_value_index; + status = ort_value_map.GetIdx(def->Name(), ort_value_index); + if (!status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to find node output or a constant initializer producing output: ", + def->Name(), "."); + } + + const auto& idx_to_ort_value = session_state_->GetInitializedTensors(); + auto it = idx_to_ort_value.find(ort_value_index); + if (it == idx_to_ort_value.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to find node output or a constant initializer producing output: ", + def->Name(), "."); + } + const auto& tensor = it->second.Get<Tensor>(); + auto allocator = session_state_->GetAllocator(tensor.Location()); + memory_info.push_back(&allocator->Info()); + } + } else { + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. + auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } } return Status::OK(); 
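The fallback chain above (node info, then a pass-through input, then a constant initializer) backs the session-level memory info queries added elsewhere in this diff. A hedged C++ usage sketch (assumes an existing Ort::Session named session; GetMemoryInfoForOutputs mirrors its use in the pybind code later in this diff):

// Ask where each session output will live, e.g. to decide whether a host copy is needed.
Ort::ConstSession const_session = session.GetConst();
auto output_mem_infos = const_session.GetMemoryInfoForOutputs();
for (const auto& mem_info : output_mem_infos) {
  const OrtMemoryInfo* info = mem_info;  // convertible, as in the binding code in this diff
  (void)info;
}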
@@ -3422,15 +3463,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector node_info_vec; ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); - - // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map - // instead of doing a linear search each time. - const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType(); - auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { - return entry->ep_name == ep_name; - }); - - ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + assert(!node_info_vec.empty()); + // If we have an input that is not consumed by any node, + // including nodes in subgraphs, then we return nullptr. + const auto* p_node = node_info_vec.front().p_node; + if (p_node != nullptr) { + const auto ep_name = p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + } else { + ep_devices.push_back(nullptr); + } } return Status::OK(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 21d09df5cc4db..edc0cb6d2bd0f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4257,7 +4257,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.23.0", +static_assert(std::string_view(ORT_VERSION) == "1.23.1", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it // 2. If there were any APIs added to ort_api_1_to_23 above: diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index c8829423fbe26..55245420db37a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -3,6 +3,7 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include <algorithm> #include #include #include @@ -117,6 +118,17 @@ static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_ return device_memory_info != nullptr ? device_memory_info->device : OrtDevice(); } +static const Node* FindFirstNodeAssignedToOtherEP(const std::string& ep_type, + gsl::span ep_nodes) { + auto node_iter = std::find_if(ep_nodes.begin(), ep_nodes.end(), + [&ep_type](const EpNode* node) -> bool { + const auto& node_ep_type = node->GetInternalNode().GetExecutionProviderType(); + return !node_ep_type.empty() && node_ep_type != ep_type; + }); + + return node_iter != ep_nodes.end() ? &(*node_iter)->GetInternalNode() : nullptr; +} + PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, gsl::span ep_devices, @@ -158,9 +170,11 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed? + const logging::Logger& logger = GetLogger() != nullptr ? 
*GetLogger() : logging::LoggingManager::DefaultLogger(); + std::unique_ptr<EpGraph> ep_graph = nullptr; if (Status status = EpGraph::Create(graph_viewer, ep_graph); !status.IsOK()) { - LOGS_DEFAULT(ERROR) << "Failed to create OrtGraph: " << status.ToString(); + LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString(); return {}; } @@ -168,7 +182,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info)); if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() failed with error: " << status.ToString(); + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " failed with error: " << status.ToString(); return {}; } @@ -182,12 +196,39 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) { + // Skip this node grouping if any node has already been assigned to another EP. + if (const Node* node_for_other_ep = FindFirstNodeAssignedToOtherEP(Type(), node_grouping.nodes); + node_for_other_ep != nullptr) { + LOGS(logger, WARNING) << "OrtEp::GetCapability() specified nodes that cannot be assigned to " << Type() << ". " + << "Found one or more nodes that were already assigned to a different EP named '" + << node_for_other_ep->GetExecutionProviderType() << "'. Ex: " + << node_for_other_ep->OpType() << " node with name '" + << node_for_other_ep->Name() << "'."; + continue; + } + if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) { + if (node_grouping.nodes.size() != 1) { + // The EpGraphSupportInfo_AddSingleNode() C API should already return an error if the EP tries to provide + // an invalid node. However, we check here too just in case this changes. + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " did not specify exactly one valid node " + << "when calling EpGraphSupportInfo_AddSingleNode()."; + return {}; + } + auto indexed_sub_graph = std::make_unique<IndexedSubGraph>(); indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index()); result.push_back(std::make_unique<ComputeCapability>(std::move(indexed_sub_graph))); } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { + if (node_grouping.nodes.empty()) { + // The EpGraphSupportInfo_AddNodesToFuse() C API should already return an error if the EP tries to provide + // an empty array of nodes from OrtEp::GetCapability(). However, we check here too just in case this changes. + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes " + << "when specifying supported nodes."; + return {}; + } + std::unordered_set<const Node*> node_set; node_set.reserve(node_grouping.nodes.size()); @@ -207,27 +248,29 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie this->Type(), this->Type(), /*node_unit_map*/ nullptr, node_grouping.fusion_options.drop_constant_initializers); - if (capabilities.size() > 1) { - LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. 
" - << "Please ensure that the nodes provided to EpGraphSupportInfo_AddFusedNodes() do not " + if (capabilities.size() != 1) { + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set nodes that cannot be fused together. " + << "Please ensure that the nodes provided to EpGraphSupportInfo_AddNodesToFuse() do not " << "have an unsupported node in any path between two of the supported nodes."; return {}; } - // Enforce that the nodes in node_set match the nodes in capabilities[0] + // Log an error if the nodes in node_set do not match the nodes in capabilities[0]. We expect this to always + // be true because we've already checked that the EP did not try to claim nodes already assigned to another EP. // TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above. std::vector& capability_node_indices = capabilities[0]->sub_graph->nodes; std::unordered_set capability_node_indices_set(capability_node_indices.begin(), capability_node_indices.end()); - ORT_ENFORCE(node_set.size() == capability_node_indices_set.size()); - ORT_ENFORCE(std::all_of(node_set.begin(), node_set.end(), [&capability_node_indices_set](const Node* node) { - return capability_node_indices_set.count(node->Index()) != 0; - })); + if (node_set.size() != capability_node_indices_set.size()) { + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() + << " set nodes that cannot all be fused together."; + return {}; + } result.push_back(std::move(capabilities[0])); } else { - LOGS_DEFAULT(ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " + LOGS(logger, ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " << static_cast(node_grouping.kind); return {}; } diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 35abad5760c32..4c3313046457c 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -199,6 +199,18 @@ def get_modelmeta(self) -> onnxruntime.ModelMetadata: "Return the metadata. See :class:`onnxruntime.ModelMetadata`." return self._model_meta + def get_input_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]: + "Return the memory info for the inputs." + return self._input_meminfos + + def get_output_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]: + "Return the memory info for the outputs." + return self._output_meminfos + + def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]: + "Return the execution providers for the inputs." + return self._input_epdevices + def get_providers(self) -> Sequence[str]: "Return list of registered execution providers." 
return self._providers @@ -576,6 +588,9 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi self._inputs_meta = self._sess.inputs_meta self._outputs_meta = self._sess.outputs_meta self._overridable_initializers = self._sess.overridable_initializers + self._input_meminfos = self._sess.input_meminfos + self._output_meminfos = self._sess.output_meminfos + self._input_epdevices = self._sess.input_epdevices self._model_meta = self._sess.model_meta self._providers = self._sess.get_providers() self._provider_options = self._sess.get_provider_options() @@ -589,6 +604,9 @@ def _reset_session(self, providers, provider_options) -> None: self._inputs_meta = None self._outputs_meta = None self._overridable_initializers = None + self._input_meminfos = None + self._output_meminfos = None + self._input_epdevices = None self._model_meta = None self._providers = None self._provider_options = None @@ -1134,6 +1152,15 @@ def update_inplace(self, np_arr) -> None: self._ortvalue.update_inplace(np_arr) +def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None: + """ + Copy tensor data from source OrtValue sequence to destination OrtValue sequence. + """ + c_sources = [s._get_c_value() for s in src] + c_dsts = [d._get_c_value() for d in dst] + C.copy_tensors(c_sources, c_dsts, stream) + + class OrtDevice: """ A data structure that exposes the underlying C++ OrtDevice @@ -1146,6 +1173,7 @@ def __init__(self, c_ort_device): if isinstance(c_ort_device, C.OrtDevice): self._ort_device = c_ort_device else: + # An end user won't hit this error raise ValueError( "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`" ) @@ -1188,6 +1216,9 @@ def device_type(self): def device_vendor_id(self): return self._ort_device.vendor_id() + def device_mem_type(self): return self._ort_device.mem_type() + class SparseTensor: """ diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 1fe7ab0884f9c..d74663ddb63d7 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -333,7 +333,7 @@ void addOrtValueMethods(pybind11::module& m) { }) #endif // Get a pointer to Tensor data - .def("data_ptr", [](OrtValue* ml_value) -> int64_t { + .def("data_ptr", [](OrtValue* ml_value) -> uintptr_t { // TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported"); @@ -344,7 +344,7 @@ void addOrtValueMethods(pybind11::module& m) { } // Should cover x86 and x64 platforms - return reinterpret_cast<int64_t>(tensor->MutableDataRaw()); + return reinterpret_cast<uintptr_t>(tensor->MutableDataRaw()); }) .def("device_name", [](const OrtValue* ort_value) -> std::string { if (ort_value->IsTensor()) { 
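The copy_tensors binding added below forwards to the C API's CopyTensors entry point. For reference, a hedged sketch of the equivalent direct call from C++ (the vector element types are assumptions; env is a const OrtEnv*; a null stream requests a synchronous copy):

std::vector<const OrtValue*> sources = /* source tensors */ {};
std::vector<OrtValue*> destinations = /* preallocated destination tensors */ {};
Ort::ThrowOnError(Ort::GetApi().CopyTensors(env, sources.data(), destinations.data(),
                                            /*stream*/ nullptr, sources.size()));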
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index e370518b1fffb..479898beae83e 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -22,6 +22,7 @@ #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" #include "core/framework/error_code_helper.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/provider_options_utils.h" #include "core/framework/random_seed.h" #include "core/framework/sparse_tensor.h" @@ -1584,6 +1585,18 @@ void addGlobalMethods(py::module& m) { }, R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc"); + m.def( + "copy_tensors", + [](const std::vector<const OrtValue*>& src, const std::vector<OrtValue*>& dest, py::object& py_arg) { + const OrtEnv* ort_env = GetOrtEnv(); + OrtSyncStream* stream = nullptr; + if (!py_arg.is_none()) { + stream = py_arg.cast<OrtSyncStream*>(); + } + Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size())); + }, + R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc"); + #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( "get_available_openvino_device_ids", []() -> std::vector<std::string> { @@ -1785,6 +1798,16 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .value("CPU", OrtMemTypeCPU) .value("DEFAULT", OrtMemTypeDefault); + py::enum_<OrtMemoryInfoDeviceType>(m, "OrtMemoryInfoDeviceType") + .value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) + .value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) + .value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU) + .value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA); + + py::enum_<OrtDeviceMemoryType>(m, "OrtDeviceMemoryType") + .value("DEFAULT", OrtDeviceMemoryType_DEFAULT) + .value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE); + py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc"); device.def(py::init()) .def(py::init([](OrtDevice::DeviceType type, @@ -1813,6 +1836,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc") .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc") .def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc") + .def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc") // generic device types that are typically used with a vendor id. 
.def_static("cpu", []() { return OrtDevice::CPU; }) .def_static("gpu", []() { return OrtDevice::GPU; }) @@ -1863,36 +1887,58 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra }, R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc"); + py::class_ py_sync_stream(m, "OrtSyncStream", + R"pbdoc(Represents a synchronization stream for model inference.)pbdoc"); + py_sync_stream.def("get_handle", [](OrtSyncStream* stream) -> uintptr_t { + Ort::UnownedSyncStream ort_stream(stream); + return reinterpret_cast(ort_stream.GetHandle()); }, R"pbdoc(SyncStream handle that can be converted to a string and added to SessionOptions)pbdoc"); + py::class_ py_ep_device(m, "OrtEpDevice", R"pbdoc(Represents a hardware device that an execution provider supports for model inference.)pbdoc"); py_ep_device.def_property_readonly( "ep_name", - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, R"pbdoc(The execution provider's name.)pbdoc") .def_property_readonly( "ep_vendor", - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, R"pbdoc(The execution provider's vendor name.)pbdoc") .def_property_readonly( "ep_metadata", - [](OrtEpDevice* ep_device) -> std::map { + [](const OrtEpDevice* ep_device) -> std::map { return ep_device->ep_metadata.Entries(); }, R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc") .def_property_readonly( "ep_options", - [](OrtEpDevice* ep_device) -> std::map { + [](const OrtEpDevice* ep_device) -> std::map { return ep_device->ep_options.Entries(); }, R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc") .def_property_readonly( "device", - [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& { + [](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& { return *ep_device->device; }, R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc", - py::return_value_policy::reference_internal); + py::return_value_policy::reference_internal) + .def( + "memory_info", + [](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* { + Ort::ConstEpDevice ep_dev(ep_device); + return static_cast(ep_dev.GetMemoryInfo(memory_type)); + }, + R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc", + py::return_value_policy::reference_internal) + .def( + "create_sync_stream", + [](const OrtEpDevice* ep_device) -> std::unique_ptr { + Ort::ConstEpDevice ep_dev(ep_device); + Ort::SyncStream stream = ep_dev.CreateSyncStream(); + return std::unique_ptr(stream.release()); + }, + R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc"); py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option. 
@@ -1938,25 +1984,28 @@ for model inference.)pbdoc"); .def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes); py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo"); - ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { - if (strcmp(name, onnxruntime::CPU) == 0) { - return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), mem_type); - } else if (strcmp(name, onnxruntime::CUDA) == 0) { - return std::make_unique<OrtMemoryInfo>( - onnxruntime::CUDA, type, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, - static_cast<OrtDevice::DeviceId>(id)), - mem_type); - } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) { - return std::make_unique<OrtMemoryInfo>( - onnxruntime::CUDA_PINNED, type, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - static_cast<OrtDevice::DeviceId>(id)), - mem_type); - } else { - throw std::runtime_error("Specified device is not supported."); - } - })); + ort_memory_info_binding.def( + py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { + Ort::MemoryInfo result(name, type, id, mem_type); + return std::unique_ptr<OrtMemoryInfo>(result.release()); + })) + .def_static( + "create_v2", + [](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, + int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) { + Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type); + return std::unique_ptr<OrtMemoryInfo>(result.release()); + }, + R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc") + .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc") + .def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc") + .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc") + .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc") + .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType { + auto mem_type = mem_info->device.MemType(); + return (mem_type == OrtDevice::MemType::DEFAULT) ? 
+ OrtDeviceMemoryType_DEFAULT : OrtDeviceMemoryType_HOST_ACCESSIBLE; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc") + .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); }); py::class_<PySessionOptions> sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc"); @@ -2653,6 +2702,33 @@ including arg name, arg type (contains both type and shape).)pbdoc") auto res = sess->GetSessionHandle()->GetModelMetadata(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) + .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); + auto inputs_mem_info = session.GetMemoryInfoForInputs(); + py::list result; + for (const auto& info : inputs_mem_info) { + const auto* p_info = static_cast<const OrtMemoryInfo*>(info); + result.append(py::cast(p_info, py::return_value_policy::reference)); + } + return result; }) + .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); + auto outputs_mem_info = session.GetMemoryInfoForOutputs(); + py::list result; + for (const auto& info : outputs_mem_info) { + const auto* p_info = static_cast<const OrtMemoryInfo*>(info); + result.append(py::cast(p_info, py::return_value_policy::reference)); + } + return result; }) + .def_property_readonly("input_epdevices", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); + auto ep_devices = session.GetEpDeviceForInputs(); + py::list result; + for (const auto& device : ep_devices) { + const auto* p_device = static_cast<const OrtEpDevice*>(device); + result.append(py::cast(p_device, py::return_value_policy::reference)); + } + return result; }) .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void { Status status; diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index e4265713d2d0a..5d8245618dcd6 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -233,7 +233,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); - if (op_type != "Mul") { + if (op_type == "Mul") { // Check that Mul has inputs/output of type float std::vector<Ort::ConstValueInfo> inputs = node.GetInputs(); std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs(); @@ -248,11 +248,36 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG continue; // Input or output is not of type float } + { + const auto input_0_shape = GetTensorShape(inputs[0]), + input_1_shape = GetTensorShape(inputs[1]); + + if (!input_0_shape.has_value() || !input_1_shape.has_value()) { + continue; // unable to get input shape + } + + const auto is_static_shape = [](gsl::span<const int64_t> shape) -> bool { + return std::all_of(shape.begin(), shape.end(), [](int64_t dim) { return dim >= 0; }); + }; + + if (!is_static_shape(*input_0_shape) || !is_static_shape(*input_1_shape)) { + continue; // input shape has dynamic dimensions + } + + if (*input_0_shape != *input_1_shape) { + continue; // input shapes do not match (no broadcasting support for now) + } + } + supported_nodes.push_back(node); // Only support a single Mul for now. 
break; } } + if (supported_nodes.empty()) { + return nullptr; + } + // Create (optional) fusion options for the supported nodes to fuse. OrtNodeFusionOptions node_fusion_options = {}; node_fusion_options.ort_version_supported = ORT_API_VERSION; @@ -317,7 +342,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const Ort::ConstNode fused_node{fused_nodes[0]}; auto ep_name = fused_node.GetEpName(); - if (ep_name != "example_ep") { + if (ep_name != ep->name_) { Ort::Status status("The fused node is expected to be assigned to this EP to run on", ORT_EP_FAIL); return status.release(); } diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/ep_stream_support.cc index 1f6c16a8cb358..c648474d4fad7 100644 --- a/onnxruntime/test/autoep/library/ep_stream_support.cc +++ b/onnxruntime/test/autoep/library/ep_stream_support.cc @@ -61,7 +61,12 @@ OrtStatus* ORT_API_CALL NotificationImpl::ActivateImpl(_In_ OrtSyncNotificationI /*static*/ OrtStatus* ORT_API_CALL NotificationImpl::WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, _In_ OrtSyncStream* stream) noexcept { + if (stream == nullptr) { + return nullptr; + } + auto& impl = *static_cast<NotificationImpl*>(this_ptr); + void* handle = impl.ort_api.SyncStream_GetHandle(stream); static_cast<void>(handle); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc index 263b4d208bd91..8b36f5f4e9a13 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc @@ -35,3 +35,14 @@ void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { } result = true; } + +std::optional<std::vector<int64_t>> GetTensorShape(Ort::ConstValueInfo value_info) { + const auto type_info = value_info.TypeInfo(); + const auto onnx_type = type_info.GetONNXType(); + if (onnx_type != ONNX_TYPE_TENSOR) { + return std::nullopt; + } + + const auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + return type_shape.GetShape(); +} 
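A hedged usage sketch of this helper from an EP's node-selection loop (mirrors the static-shape checks added to ep.cc earlier in this diff; assumes an Ort::ConstValueInfo named value_info is in scope):

#include <algorithm>
#include <cstdint>

// Reject values that are not tensors or that have any dynamic (negative) dimension.
bool HasStaticTensorShape(Ort::ConstValueInfo value_info) {
  const auto shape = GetTensorShape(value_info);
  if (!shape.has_value()) return false;  // not a tensor
  return std::all_of(shape->begin(), shape->end(), [](int64_t dim) { return dim >= 0; });
}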
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index e8c086d38a7cb..decc89251dc7b 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -4,7 +4,9 @@ #pragma once #include +#include <optional> #include +#include <vector> #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" @@ -108,3 +110,6 @@ OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessio // Returns true (via output parameter) if the given OrtValueInfo represents a float tensor. void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result); + +// Gets the tensor shape from `value_info`. Returns std::nullopt if `value_info` is not a tensor. +std::optional<std::vector<int64_t>> GetTensorShape(Ort::ConstValueInfo value_info); diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 0f4a654f116c4..78be22d082692 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -22,7 +22,8 @@ namespace onnxruntime { namespace test { namespace { -void RunModelWithPluginEp(Ort::SessionOptions& session_options) { + +void RunMulModelWithPluginEp(const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); // Create input @@ -47,6 +48,38 @@ void RunModelWithPluginEp(Ort::SessionOptions& session_options) { gsl::span<const float> output_span(output_data, 6); EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); } + +void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_options) { + // This model has Add -> Mul -> Add. The example plugin EP only supports Mul. + Ort::Session session(*ort_env, ORT_TSTR("testdata/add_mul_add.onnx"), session_options); + + // Create inputs + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector<int64_t> shape = {3, 2}; + + std::vector<float> a_data{1, 2, 3, 4, 5, 6}; + std::vector<float> b_data{2, 3, 4, 5, 6, 7}; + + std::vector<Ort::Value> ort_inputs{}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), shape.data(), shape.size())); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, b_data.data(), b_data.size(), shape.data(), shape.size())); + + std::array<const char*, 2> ort_input_names{"A", "B"}; + + // Run session and get outputs + std::array<const char*, 1> output_names{"C"}; + std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData<float>(); + gsl::span<const float> output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(7, 17, 31, 49, 71, 97)); +} + } // namespace // Creates a session with the example plugin EP and runs a model with a single Mul node. @@ -61,7 +94,7 @@ TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { std::unordered_map<std::string, std::string> ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - RunModelWithPluginEp(session_options); + RunMulModelWithPluginEp(session_options); } // Creates a session with the example plugin EP and runs a model with a single Mul node. @@ -74,10 +107,23 @@ TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. Ort::SessionOptions session_options; session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); - RunModelWithPluginEp(session_options); + RunMulModelWithPluginEp(session_options); } } +TEST(OrtEpLibrary, PluginEp_AppendV2_PartiallySupportedModelInference) { + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map<std::string, std::string> ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + RunPartiallySupportedModelWithPluginEp(session_options); +} + // Generate an EPContext model with a plugin EP. 
// This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs. TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { @@ -98,6 +144,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 0690b8894eb7a..ab740ea38fb74 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -1690,6 +1690,97 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { #endif } +// Test for CPU MoE implementation +static void RunMoECpuTest(const std::vector<float>& input, const std::vector<float>& router_probs, + const std::vector<float>& fc1_experts_weights, const std::vector<float>& fc2_experts_weights, + const std::vector<float>& fc3_experts_weights, const std::vector<float>& fc1_experts_bias, + const std::vector<float>& fc2_experts_bias, const std::vector<float>& output_data, int num_rows, + int num_experts, int hidden_size, int inter_size, std::string activation_type, + int normalize_routing_weights = 1, int top_k = 1) { + OpTester tester("MoE", 1, onnxruntime::kMSDomain); + tester.AddAttribute("k", static_cast<int64_t>(top_k)); + tester.AddAttribute("activation_type", activation_type); + tester.AddAttribute("normalize_routing_weights", static_cast<int64_t>(normalize_routing_weights)); + + bool is_swiglu = (activation_type == "swiglu"); + + if (is_swiglu) { + tester.AddAttribute("swiglu_fusion", static_cast<int64_t>(1)); + tester.AddAttribute("activation_beta", 1.0f); + } + + std::vector<int64_t> input_dims = {num_rows, hidden_size}; + std::vector<int64_t> router_probs_dims = {num_rows, num_experts}; + 
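+ // Note (illustrative): SwiGLU splits the FC1 projection into a linear half and a
+ // gate half and combines them elementwise (roughly out = linear * silu(gate)),
+ // which is why fc1_output_size below doubles to 2 * inter_size when is_swiglu.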
diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc
index 0690b8894eb7a..ab740ea38fb74 100644
--- a/onnxruntime/test/contrib_ops/moe_test.cc
+++ b/onnxruntime/test/contrib_ops/moe_test.cc
@@ -1690,6 +1690,97 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) {
 #endif
 }
 
+// Test for CPU MoE implementation
+static void RunMoECpuTest(const std::vector<float>& input, const std::vector<float>& router_probs,
+                          const std::vector<float>& fc1_experts_weights, const std::vector<float>& fc2_experts_weights,
+                          const std::vector<float>& fc3_experts_weights, const std::vector<float>& fc1_experts_bias,
+                          const std::vector<float>& fc2_experts_bias, const std::vector<float>& output_data, int num_rows,
+                          int num_experts, int hidden_size, int inter_size, std::string activation_type,
+                          int normalize_routing_weights = 1, int top_k = 1) {
+  OpTester tester("MoE", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute("k", static_cast<int64_t>(top_k));
+  tester.AddAttribute("activation_type", activation_type);
+  tester.AddAttribute("normalize_routing_weights", static_cast<int64_t>(normalize_routing_weights));
+
+  bool is_swiglu = (activation_type == "swiglu");
+
+  if (is_swiglu) {
+    tester.AddAttribute("swiglu_fusion", static_cast<int64_t>(1));
+    tester.AddAttribute("activation_beta", 1.0f);
+  }
+
+  std::vector<int64_t> input_dims = {num_rows, hidden_size};
+  std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
+
+  int64_t fc1_output_size = is_swiglu ? (2 * inter_size) : inter_size;
+
+  std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, fc1_output_size};
+  std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
+  std::vector<int64_t> fc3_experts_weights_dims = fc1_experts_weights_dims;
+  std::vector<int64_t> fc1_experts_bias_dims = {num_experts, fc1_output_size};
+  std::vector<int64_t> fc2_experts_bias_dims = {num_experts, hidden_size};
+  std::vector<int64_t> output_dims = {num_rows, hidden_size};
+
+  tester.AddInput<float>("input", input_dims, input);
+  tester.AddInput<float>("router_probs", router_probs_dims, router_probs);
+  tester.AddInput<float>("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights);
+  if (!fc1_experts_bias.empty()) {
+    tester.AddInput<float>("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias);
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+  tester.AddInput<float>("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights);
+  if (!fc2_experts_bias.empty()) {
+    tester.AddInput<float>("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias);
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+  if (!fc3_experts_weights.empty()) {
+    tester.AddInput<float>("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights);
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+  tester.AddOptionalInputEdge<float>();  // fc3_experts_bias
+
+  tester.AddOutput<float>("output", output_dims, output_data);
+  tester.SetOutputTolerance(0.05f);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MoETest, MoECpuTest_BasicSwiGLU) {
+  int num_rows = 2;
+  int num_experts = 2;
+  int hidden_size = 4;
+  int inter_size = 8;
+
+  // Simple test data
+  const std::vector<float> input = {
+      1.0f, 2.0f, 3.0f, 4.0f,
+      5.0f, 6.0f, 7.0f, 8.0f};
+
+  const std::vector<float> router_probs = {
+      0.8f, 0.2f,
+      0.3f, 0.7f};
+
+  const std::vector<float> fc1_experts_weights(num_experts * hidden_size * (2 * inter_size), 0.1f);
+
+  const std::vector<float> fc2_experts_weights(num_experts * inter_size * hidden_size, 0.1f);
+
+  const std::vector<float> fc3_experts_weights = {};  // No FC3
+  const std::vector<float> fc1_experts_bias = {};     // No bias
+  const std::vector<float> fc2_experts_bias = {};     // No bias
+
+  const std::vector<float> output_data = {
+      1.169694f, 1.169694f, 1.169694f, 1.169694f,
+      6.970291f, 6.970291f, 6.970291f, 6.970291f};
+
+  RunMoECpuTest(input, router_probs, fc1_experts_weights, fc2_experts_weights,
+                fc3_experts_weights, fc1_experts_bias, fc2_experts_bias, output_data,
+                num_rows, num_experts, hidden_size, inter_size, "swiglu");
+}
 #endif
 
 }  // namespace test
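A note on the router semantics this test exercises: with k = 1 and normalize_routing_weights = 1, each row's routing weight collapses to 1.0 for its highest-probability expert. A minimal standalone sketch of that selection, under the assumption that routing applies a softmax over router_probs before the top-k/renormalize step:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Same router_probs as MoECpuTest_BasicSwiGLU: {2 rows, 2 experts}.
  std::vector<float> router_probs = {0.8f, 0.2f, 0.3f, 0.7f};
  const int num_experts = 2;
  for (int row = 0; row < 2; ++row) {
    const float* logits = &router_probs[row * num_experts];
    // Numerically stabilized softmax, then pick the top-1 expert.
    float mx = *std::max_element(logits, logits + num_experts);
    float denom = 0.0f;
    std::vector<float> p(num_experts);
    for (int e = 0; e < num_experts; ++e) denom += (p[e] = std::exp(logits[e] - mx));
    int best = static_cast<int>(std::max_element(p.begin(), p.end()) - p.begin());
    float top_p = p[best] / denom;  // softmax probability of the selected expert
    // With k = 1, renormalizing the single kept weight always yields 1.0.
    std::printf("row %d -> expert %d (softmax %.3f, top-1 weight %.1f)\n", row, best, top_p, 1.0f);
  }
  return 0;
}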
diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc
index 35f7d06fb0912..30595d5ce97b2 100644
--- a/onnxruntime/test/framework/ep_plugin_provider_test.cc
+++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc
@@ -3,9 +3,14 @@
 
 #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h"
 
+#include <fstream>
+
 #include "gsl/gsl"
 #include "gtest/gtest.h"
 
+#include "core/common/logging/sinks/file_sink.h"
+#include "core/graph/graph_viewer.h"
+#include "core/graph/model.h"
+#include "core/optimizer/graph_optimizer_registry.h"
 #include "core/session/abi_devices.h"
 #include "core/session/onnxruntime_cxx_api.h"
 #include "test/util/include/asserts.h"
@@ -23,6 +28,14 @@ struct ApiPtrs {
   const gsl::not_null<const OrtEpApi*> ep_api;
 };
 
+static void CheckStringInFile(const PathString& filename, const std::string& look_for) {
+  std::ifstream ifs{filename};
+  std::string content(std::istreambuf_iterator<char>{ifs},
+                      std::istreambuf_iterator<char>{});
+
+  EXPECT_NE(content.find(look_for), std::string::npos);
+}
+
 // Normally, a plugin EP would be implemented in a separate library.
 // The `test_plugin_ep` namespace contains a local implementation intended for unit testing.
 namespace test_plugin_ep {
@@ -114,6 +127,10 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {
   return result;
 }
 
+class MockKernelLookup : public IExecutionProvider::IKernelLookup {
+  const KernelCreateInfo* LookUpKernel(const Node& /*node*/) const override { return nullptr; }
+};
+
 }  // namespace test_plugin_ep
 
 TEST(PluginExecutionProviderTest, GetPreferredLayout) {
@@ -317,4 +334,218 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) {
 #endif  // !defined(ORT_NO_EXCEPTIONS)
 }
 
+static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path,
+                                        const char* ep_name,
+                                        const std::unordered_set<std::string>& ep_node_names,
+                                        /*out*/ std::shared_ptr<Model>& model) {
+  ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr,
+                               DefaultLoggingManager().DefaultLogger()));
+
+  Graph& graph = model->MainGraph();
+
+  for (Node& node : graph.Nodes()) {
+    if (ep_node_names.count(node.Name()) > 0) {
+      node.SetExecutionProviderType(ep_name);
+    }
+  }
+}
+
+static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesOneGroup(OrtEp* this_ptr, const OrtGraph* graph,
+                                                                 OrtEpGraphSupportInfo* graph_support_info) noexcept {
+  auto* this_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);
+
+  size_t num_nodes = 0;
+  if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) {
+    return st;
+  }
+
+  std::vector<const OrtNode*> nodes(num_nodes);
+  if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) {
+    return st;
+  }
+
+  if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
+                                                                         nodes.data(), nodes.size(), nullptr);
+      st != nullptr) {
+    return st;
+  }
+
+  return nullptr;
+}
+
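The repeated `if (OrtStatus* st = ...; st != nullptr) return st;` pattern in these callbacks could be factored out; a small sketch of such a helper macro (purely illustrative, not part of this patch or the ORT API):

// Illustrative only: early-return on a failed OrtStatus-producing call.
#define RETURN_IF_ORT_ERROR(expr)                          \
  do {                                                     \
    if (OrtStatus* _status = (expr); _status != nullptr) { \
      return _status;                                      \
    }                                                      \
  } while (0)

// Hypothetical usage inside one of the GetCapability callbacks:
//   RETURN_IF_ORT_ERROR(this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes));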
+static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesTwoGroups(OrtEp* this_ptr, const OrtGraph* graph,
+                                                                  OrtEpGraphSupportInfo* graph_support_info) noexcept {
+  auto* this_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);
+
+  size_t num_nodes = 0;
+  if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) {
+    return st;
+  }
+
+  std::vector<const OrtNode*> nodes(num_nodes);
+  if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) {
+    return st;
+  }
+
+  // Expect at least 2 nodes. If not, this is really a testing/setup error.
+  if (num_nodes < 2) {
+    return this_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL,
+                                          "Expected at least two nodes in call to GetCapability");
+  }
+
+  std::vector<const OrtNode*> node_group1;
+  std::vector<const OrtNode*> node_group2;
+
+  for (size_t i = 0; i < num_nodes; i++) {
+    if (i < num_nodes / 2) {
+      node_group1.push_back(nodes[i]);
+    } else {
+      node_group2.push_back(nodes[i]);
+    }
+  }
+
+  if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
+                                                                         node_group1.data(), node_group1.size(),
+                                                                         nullptr);
+      st != nullptr) {
+    return st;
+  }
+
+  if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
+                                                                         node_group2.data(), node_group2.size(),
+                                                                         nullptr);
+      st != nullptr) {
+    return st;
+  }
+
+  return nullptr;
+}
+
+static OrtStatus* ORT_API_CALL GetCapabilityTakeSingleNode(OrtEp* this_ptr, const OrtGraph* graph,
+                                                           OrtEpGraphSupportInfo* graph_support_info) noexcept {
+  auto* this_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);
+
+  size_t num_nodes = 0;
+  if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) {
+    return st;
+  }
+
+  std::vector<const OrtNode*> nodes(num_nodes);
+  if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) {
+    return st;
+  }
+
+  // Take only the first node using EpGraphSupportInfo_AddSingleNode().
+  if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, nodes[0]);
+      st != nullptr) {
+    return st;
+  }
+
+  return nullptr;
+}
+
+// Tests that GetCapability() doesn't crash if a plugin EP tries to claim a mix of unassigned nodes and
+// nodes that are already assigned to another EP.
+TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) {
+  std::filesystem::path log_file = ORT_TSTR("log_get_capability.txt");
+
+  // Helper function that loads a model (Add -> Mul -> Add) and assigns some or all of the nodes to another EP.
+  // Then, IExecutionProvider::GetCapability() is called to test the expected behavior.
+  auto run_test = [&log_file](IExecutionProvider& ep,
+                              const std::unordered_set<std::string>& nodes_for_other_ep,
+                              const std::unordered_set<std::string>& nodes_for_this_ep,
+                              const char* expected_log_string) {
+    std::shared_ptr<Model> model;
+    ASSERT_NO_FATAL_FAILURE(LoadModelAndAssignNodesToEp(ORT_TSTR("testdata/add_mul_add.onnx"),
+                                                        "OtherEp", nodes_for_other_ep, model));
+
+    std::filesystem::remove(log_file);
+
+    // Call IExecutionProvider::GetCapability and check results + logs.
+    {
+      logging::LoggingManager log_manager{std::make_unique<logging::FileSink>(log_file, false, false),
+                                          logging::Severity::kWARNING, false,
+                                          logging::LoggingManager::InstanceType::Temporal};
+      auto file_logger = log_manager.CreateLogger("FileLogger");
+      ep.SetLogger(file_logger.get());  // Make EP log to a file.
+
+      GraphViewer graph_viewer(model->MainGraph());
+      auto compute_capabilities = ep.GetCapability(graph_viewer,
+                                                   test_plugin_ep::MockKernelLookup{},
+                                                   GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()),
+                                                   nullptr);
+
+      ASSERT_EQ(compute_capabilities.size(), nodes_for_this_ep.empty() ? 0 : 1);
+
+      if (compute_capabilities.size() == 1) {
+        ASSERT_EQ(compute_capabilities[0]->sub_graph->nodes.size(), nodes_for_this_ep.size());
+
+        for (NodeIndex node_index : compute_capabilities[0]->sub_graph->nodes) {
+          const Node* node = graph_viewer.GetNode(node_index);
+          ASSERT_NE(node, nullptr);
+          EXPECT_EQ(nodes_for_this_ep.count(node->Name()), 1);
+        }
+      }
+    }
+
+    ASSERT_TRUE(std::filesystem::exists(log_file));
+    EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string));
+  };
+
+  constexpr std::array<const char*, 3> node_names = {"add_0", "mul_0", "add_1"};
+
+  auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp();
+
+  // Load a model and assign all of its nodes to another EP named 'OtherEp'.
+  // The plugin EP tries to claim all nodes in a single group via EpGraphSupportInfo_AddNodesToFuse.
+  // IExecutionProvider::GetCapability() should return an empty result and log a warning.
+  ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup;
+  std::unordered_set<std::string> nodes_for_other_ep = {"add_0", "mul_0", "add_1"};
+  std::unordered_set<std::string> nodes_for_this_ep;
+  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
+           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
+
+  // Load a model and assign only one node to another EP named 'OtherEp'.
+  // The plugin EP tries to claim all nodes in a single group.
+  // IExecutionProvider::GetCapability() should return an empty result and log a warning.
+  ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup;
+  for (const char* node_name : node_names) {
+    nodes_for_other_ep = std::unordered_set<std::string>{node_name};
+    nodes_for_this_ep = std::unordered_set<std::string>{};
+    run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
+             "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
+  }
+
+  // Load a model and assign only the last Add node to another EP named 'OtherEp'.
+  // The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1).
+  // IExecutionProvider::GetCapability() will only return (add_0) because the second group has a node
+  // that was assigned to 'OtherEp'.
+  ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups;
+  nodes_for_other_ep = std::unordered_set<std::string>{"add_1"};
+  nodes_for_this_ep = std::unordered_set<std::string>{"add_0"};
+  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
+           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
+
+  // Load a model and assign only the first Add node to another EP named 'OtherEp'.
+  // The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1).
+  // IExecutionProvider::GetCapability() will only return (mul_0, add_1) because the first group has a node
+  // that was assigned to 'OtherEp'.
+  ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups;
+  nodes_for_other_ep = std::unordered_set<std::string>{"add_0"};
+  nodes_for_this_ep = std::unordered_set<std::string>{"mul_0", "add_1"};
+  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
+           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
+
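A standalone sketch of the drop-the-whole-group behavior the comments above describe (illustrative only; the real filtering happens inside IExecutionProvider::GetCapability):

#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  // Nodes already assigned to 'OtherEp', as in the (add_0) / (mul_0, add_1) case above.
  std::set<std::string> taken = {"add_1"};
  std::vector<std::vector<std::string>> requested_groups = {{"add_0"}, {"mul_0", "add_1"}};

  for (const auto& group : requested_groups) {
    bool claimable = true;
    for (const auto& node : group) {
      claimable = claimable && (taken.count(node) == 0);  // one taken node poisons the whole group
    }
    if (claimable) {
      for (const auto& node : group) std::cout << node << ' ';
      std::cout << '\n';  // prints: add_0
    }
  }
  return 0;
}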
+  // Load a model and assign the first Add node to another EP named 'OtherEp'.
+  // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode.
+  // IExecutionProvider::GetCapability() will return an empty result and log a warning.
+  ort_ep->GetCapability = GetCapabilityTakeSingleNode;
+  nodes_for_other_ep = std::unordered_set<std::string>{"add_0"};
+  nodes_for_this_ep = std::unordered_set<std::string>{};
+  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
+           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
+
+  std::filesystem::remove(log_file);
+}
+
 }  // namespace onnxruntime::test
diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc
index 6e86e5b58aead..01b253446974b 100644
--- a/onnxruntime/test/framework/memcpy_transformer_test.cc
+++ b/onnxruntime/test/framework/memcpy_transformer_test.cc
@@ -71,8 +71,18 @@ void ExpectCopy(const onnxruntime::Node& source, const std::string copy_op,
   }
   EXPECT_TRUE(false) << "Copy node expected but not found";
 }
+
 #ifdef USE_CUDA
+static InlinedVector<gsl::not_null<const IExecutionProvider*>> GetNotNullProviderPtrs(
+    const ExecutionProviders& providers) {
+  InlinedVector<gsl::not_null<const IExecutionProvider*>> not_null_provider_ptrs{};
+  for (auto& provider_ptr : providers) {
+    not_null_provider_ptrs.emplace_back(provider_ptr.get());
+  }
+  return not_null_provider_ptrs;
+}
+
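GetNotNullProviderPtrs hands the transformer non-owning, null-checked views over providers the test still owns. A minimal standalone sketch of that shape (the Provider type below is a stand-in, not ORT's):

#include <memory>
#include <vector>
#include "gsl/gsl"

struct Provider {
  int id;
};

int main() {
  // Owning container, analogous to ExecutionProviders in the test.
  std::vector<std::unique_ptr<Provider>> owned;
  owned.push_back(std::make_unique<Provider>(Provider{7}));

  // Non-owning views; gsl::not_null rejects a null pointer at construction.
  std::vector<gsl::not_null<const Provider*>> views;
  for (const auto& p : owned) {
    views.emplace_back(p.get());
  }
  return views.front()->id == 7 ? 0 : 1;
}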
 TEST(TransformerTest, MemcpyTransformerTest) {
   std::unordered_map<std::string, int> domain_to_version;
   domain_to_version[kOnnxDomain] = 7;
@@ -112,7 +122,11 @@ TEST(TransformerTest, MemcpyTransformerTest) {
   KernelRegistryManager test_registry_manager;
   ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
 
-  MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+  InlinedVector<gsl::not_null<const IExecutionProvider*>> providers;
+  for (auto& provider_ptr : execution_providers) {
+    providers.push_back(provider_ptr.get());
+  }
+  MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
 
   bool modified = false;
   status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
@@ -167,7 +181,7 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) {
   KernelRegistryManager test_registry_manager;
   ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
 
-  MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+  MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
 
   bool modified = false;
   status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
@@ -262,6 +276,8 @@ TEST(TransformerTest, TestInitializerDuplicationInSubgraph) {
   if_node.AddAttribute("then_branch", subgraph.ToGraphProto());
   if_node.AddAttribute("else_branch", subgraph.ToGraphProto());
 
+  if_node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
+
   onnxruntime::Graph* subgraph_1 = if_node.GetMutableGraphAttribute("then_branch");
   for (auto& node : subgraph_1->Nodes()) {
     if (node.Name() == "node2") {
@@ -287,7 +303,7 @@ TEST(TransformerTest, TestInitializerDuplicationInSubgraph) {
   KernelRegistryManager test_registry_manager;
   ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
 
-  MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+  MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
 
   bool modified = false;
   ASSERT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()));
@@ -329,7 +345,7 @@ TEST(TransformerTest, MemcpyTransformerTestGraphInputConsumedOnMultipleDevices)
   KernelRegistryManager test_registry_manager;
   ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
 
-  MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+  MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
 
   bool modified = false;
   status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
@@ -398,6 +414,8 @@ TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevice
   if_node.AddAttribute("then_branch", subgraph.ToGraphProto());
   if_node.AddAttribute("else_branch", subgraph.ToGraphProto());
 
+  if_node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
+
   graph.SetInputs({&i1_def, &i2_def});
 
   onnxruntime::Graph* subgraph_1 = if_node.GetMutableGraphAttribute("then_branch");
@@ -431,7 +449,7 @@ TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevice
   KernelRegistryManager test_registry_manager;
   ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
 
-  MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+  MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
 
   bool modified = false;
   status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index bef0bdd5295be..d56212510d2a9 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -795,6 +795,24 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
 
   // Please make no more changes to the list
   static const ORTCHAR_T* immutable_broken_tests[] = {
+      // pending ONNX update
+      ORT_TSTR("attention_3d_gqa"),
+      ORT_TSTR("attention_3d_gqa_attn_mask"),
+      ORT_TSTR("attention_3d_gqa_causal"),
+      ORT_TSTR("attention_3d_gqa_scaled"),
+      ORT_TSTR("attention_3d_gqa_softcap"),
+      ORT_TSTR("attention_3d_gqa_with_past_and_present"),
+      ORT_TSTR("attention_4d_gqa"),
+      ORT_TSTR("attention_4d_gqa_attn_mask"),
+      ORT_TSTR("attention_4d_gqa_causal"),
+      ORT_TSTR("attention_4d_gqa_scaled"),
+      ORT_TSTR("attention_4d_gqa_softcap"),
+      ORT_TSTR("attention_4d_gqa_with_past_and_present"),
+      ORT_TSTR("attention_4d_diff_heads_mask4d_padded_kv"),
+      ORT_TSTR("attention_4d_gqa_with_past_and_present_fp16"),
+      ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal"),
+      ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal"),
+      // unsupported case
       ORT_TSTR("AvgPool1d"),
       ORT_TSTR("AvgPool1d_stride"),
       ORT_TSTR("AvgPool2d"),
diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
index f6fce37322c10..c382612a6dff8 100644
--- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
@@ -4518,7 +4518,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) {
   // changes during the layout transformation process.
   ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
 
-  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+  using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
 
   // set the test EP to support all ops in the model so that the layout transform applies to all nodes
   const std::unordered_set<std::string> empty_set;
@@ -4544,7 +4544,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) {
          "with the exception of the initial node prior to the Conv";
 
   // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout
-  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
   for (const auto& node : graph.Nodes()) {
     EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep)
         << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP.";
@@ -4564,7 +4564,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) {
   SessionOptions so;
 
-  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+  using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
 
   // set the test EP to support all ops in the model so that the layout transform applies to all nodes
   const std::unordered_set<std::string> empty_set;
@@ -4590,7 +4590,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) {
          "with the exception of the initial node prior to the Conv";
 
   // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout
-  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
   for (const auto& node : graph.Nodes()) {
     EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep)
         << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP.";
@@ -4606,7 +4606,7 @@ TEST(TransposeOptimizerTests, QnnResizeOpset11) {
   // Uncomment to debug
   // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
 
-  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+  using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
 
   // set the test EP to support all ops in the model so that the layout transform applies to all nodes
   const std::unordered_set<std::string> empty_set;
@@ -4620,7 +4620,7 @@ TEST(TransposeOptimizerTests, QnnResizeOpset11) {
   const auto& graph = session.GetGraph();
 
   // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout
-  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
   for (const auto& node : graph.Nodes()) {
     EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep)
         << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP.";
@@ -4648,7 +4648,7 @@ TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) {
   // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
 
-  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+  using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
 
   // set the test EP to support all ops in the model so that the layout transform applies to all nodes
   const std::unordered_set<std::string> empty_set;
@@ -4666,7 +4666,7 @@ TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) {
   ASSERT_EQ(op_to_count["Transpose"], 3) << "Should have Transpose on 2 inputs and one on output.";
 
   // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout
-  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
   for (const auto& node : graph.Nodes()) {
     EXPECT_EQ(node.GetExecutionProviderType(), expected_ep)
         << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP.";
@@ -4699,7 +4699,7 @@ TEST(TransposeOptimizerTests, LayoutTransformFixStuckTransposeWithoutDQ) {
   // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
 
-  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+  using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
 
   // Set the test EP to support all ops in the model so that the layout transform applies to all nodes
   const std::unordered_set<std::string> empty_set;
@@ -4716,7 +4716,7 @@ TEST(TransposeOptimizerTests, LayoutTransformFixStuckTransposeWithoutDQ) {
 
   ASSERT_EQ(op_to_count["Transpose"], 2) << "Should have 2 transposes remaining.";
 
-  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
   for (const auto& node : graph.Nodes()) {
     EXPECT_EQ(node.GetExecutionProviderType(), expected_ep)
         << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP.";
@@ -4756,7 +4756,7 @@ TEST(TransposeOptimizerTests, LayoutTransformConstantFoldTransposeAndSqueeze) {
   // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
 
-  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+  using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
 
   // Set the test EP to support all ops in the model so that the layout transform applies to all nodes
   const std::unordered_set<std::string> empty_set;
@@ -4777,7 +4777,7 @@ TEST(TransposeOptimizerTests, LayoutTransformConstantFoldTransposeAndSqueeze) {
 
   // 1 transpose is constant-folded, 1 is canceled, and 1 remains.
ASSERT_EQ(op_to_count["Transpose"], 1) << "Should have 1 transpose remaining."; - std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider); for (const auto& node : graph.Nodes()) { EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index b4f6d328cacf7..54c2ed7d521db 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -664,6 +664,7 @@ TEST(AttentionTest, Attention4DAttnPastPresent) { false, true, true // disable_cpu, disable_cuda, disable_dml ); } + TEST(AttentionTest, Attention4DAttnIsCausal) { int batch_size = 2; // Q.shape[0] int q_num_heads = 3; // Q.shape[1] @@ -828,6 +829,38 @@ TEST(AttentionTest, Attention4DDiffHeadsWithPastAndPresent) { ); } +TEST(AttentionTest, Attention3DGqaAttn) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 9; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + // {2, 4, 72} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 
0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 
0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f}; + // {2, 6, 24} + std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 
0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 6, 24} + std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 
0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + // {2, 4, 72} + std::vector y = {0.532009f, 0.526025f, 0.449746f, 0.551692f, 0.407822f, 0.436275f, 0.507807f, 0.457324f, 0.530536f, 0.517111f, 0.452785f, 0.557318f, 0.397721f, 0.434161f, 0.498276f, 0.464536f, 0.528016f, 0.548671f, 0.441040f, 0.542961f, 0.418557f, 0.444397f, 0.515088f, 0.452512f, 0.462161f, 0.530536f, 0.564630f, 0.418701f, 0.669452f, 0.633554f, 0.569379f, 0.430544f, 0.456026f, 0.529795f, 0.558238f, 0.411985f, 0.664240f, 0.619959f, 0.590516f, 0.438577f, 0.471552f, 0.521718f, 0.560465f, 0.404206f, 0.663920f, 0.628819f, 0.540935f, 0.447763f, 0.615083f, 0.344791f, 0.432664f, 0.451253f, 0.460813f, 0.441267f, 0.708582f, 0.530088f, 0.623659f, 0.343547f, 0.439418f, 0.450767f, 0.460055f, 0.442001f, 0.703292f, 0.522883f, 0.617738f, 0.343160f, 0.440540f, 0.440079f, 0.459815f, 0.436860f, 0.703290f, 0.534856f, 0.536138f, 0.499439f, 0.465771f, 0.565138f, 0.391402f, 0.430258f, 0.494915f, 0.463613f, 0.532752f, 0.526358f, 0.452075f, 0.562130f, 0.402551f, 0.442784f, 0.486721f, 0.456955f, 0.547578f, 0.527342f, 0.453800f, 0.548887f, 0.418444f, 0.438968f, 0.515475f, 0.444207f, 0.475352f, 0.524010f, 0.549702f, 0.420030f, 0.656346f, 0.620729f, 0.571884f, 0.431010f, 0.453307f, 0.522210f, 0.563368f, 0.412061f, 0.657897f, 0.634999f, 0.577458f, 0.451691f, 0.473936f, 0.524285f, 0.553525f, 0.421768f, 0.662288f, 0.622833f, 0.570081f, 0.432808f, 0.625738f, 0.353159f, 0.436185f, 0.448597f, 0.459371f, 0.429822f, 0.709026f, 0.526207f, 0.630878f, 0.351036f, 0.439799f, 0.452249f, 0.456486f, 0.431906f, 0.706014f, 0.518897f, 0.629526f, 0.351482f, 0.440728f, 0.449287f, 0.451705f, 0.426815f, 0.706598f, 0.522028f, 0.537899f, 0.527199f, 0.447980f, 0.548688f, 0.410653f, 0.436181f, 0.511135f, 0.455244f, 0.534560f, 0.540045f, 0.447505f, 0.552786f, 0.413302f, 0.446360f, 0.499945f, 0.450757f, 0.531708f, 0.526097f, 0.450511f, 0.553372f, 0.401450f, 0.438186f, 0.501418f, 0.462466f, 0.469643f, 0.527539f, 0.553613f, 0.418159f, 0.659814f, 0.622731f, 0.575224f, 0.429425f, 0.463941f, 0.524481f, 0.557632f, 0.413729f, 0.657415f, 0.629157f, 0.570920f, 0.439773f, 0.479643f, 0.526773f, 0.556809f, 0.422406f, 0.670038f, 0.625300f, 0.554451f, 0.426587f, 0.630894f, 0.353011f, 0.444285f, 0.443177f, 0.448608f, 0.419312f, 0.705883f, 0.526260f, 0.631310f, 0.347563f, 0.445672f, 0.446224f, 0.448210f, 0.428481f, 0.702004f, 0.519990f, 0.626158f, 0.342802f, 0.449770f, 0.440666f, 0.453705f, 0.427492f, 0.700510f, 0.533279f, 0.526144f, 0.538202f, 0.443619f, 0.551579f, 0.407162f, 0.442426f, 0.499995f, 0.459987f, 0.525627f, 0.544718f, 0.448060f, 0.544942f, 0.415781f, 0.444198f, 0.516948f, 0.452985f, 0.521784f, 0.523083f, 0.450924f, 0.565538f, 0.392054f, 0.440702f, 0.479094f, 0.468113f, 0.473886f, 0.523677f, 0.555144f, 0.409412f, 0.664285f, 0.620163f, 0.555448f, 0.440947f, 0.459210f, 0.528829f, 0.567231f, 0.413602f, 0.672778f, 0.632467f, 0.565881f, 0.439895f, 0.480238f, 0.525127f, 0.554365f, 0.431656f, 0.658900f, 0.634358f, 0.561181f, 0.419623f, 0.646099f, 0.364754f, 0.442180f, 0.450340f, 0.441320f, 0.412523f, 0.708121f, 0.505939f, 0.641772f, 0.375478f, 0.428502f, 0.454772f, 0.439016f, 0.407773f, 0.718457f, 0.504047f, 0.628271f, 0.345239f, 0.449391f, 0.436208f, 0.448766f, 0.426444f, 0.699202f, 0.528374f, 0.489165f, 0.818278f, 0.467403f, 0.370507f, 0.572406f, 0.417942f, 0.160316f, 0.384139f, 0.497723f, 0.820329f, 0.455669f, 0.373132f, 0.568626f, 0.418602f, 0.164551f, 0.404233f, 0.488972f, 0.813399f, 0.460936f, 0.369774f, 0.580477f, 0.417018f, 0.167442f, 
0.381535f, 0.603715f, 0.360599f, 0.371685f, 0.614777f, 0.440767f, 0.425124f, 0.369342f, 0.828101f, 0.584460f, 0.352249f, 0.382191f, 0.613073f, 0.431223f, 0.421802f, 0.389292f, 0.831202f, 0.590574f, 0.355658f, 0.373391f, 0.623741f, 0.432416f, 0.412097f, 0.378312f, 0.829226f, 0.365226f, 0.726961f, 0.549872f, 0.239494f, 0.496434f, 0.668542f, 0.557774f, 0.487281f, 0.361340f, 0.749156f, 0.523408f, 0.240555f, 0.493770f, 0.639516f, 0.552116f, 0.478230f, 0.367118f, 0.740114f, 0.563789f, 0.238852f, 0.498407f, 0.682064f, 0.571327f, 0.496416f, 0.480636f, 0.820258f, 0.464776f, 0.362168f, 0.567256f, 0.417842f, 0.161815f, 0.387104f, 0.486998f, 0.821507f, 0.467362f, 0.377934f, 0.569593f, 0.418367f, 0.156778f, 0.390179f, 0.461449f, 0.823726f, 0.471401f, 0.361646f, 0.563554f, 0.418609f, 0.154999f, 0.379696f, 0.565916f, 0.345293f, 0.392969f, 0.612305f, 0.418858f, 0.416238f, 0.410985f, 0.833515f, 0.552881f, 0.338985f, 0.394863f, 0.597100f, 0.422296f, 0.401025f, 0.427810f, 0.831702f, 0.558983f, 0.339943f, 0.393544f, 0.583418f, 0.432193f, 0.405729f, 0.426401f, 0.830305f, 0.362801f, 0.731181f, 0.546338f, 0.247016f, 0.499389f, 0.662441f, 0.544727f, 0.486631f, 0.355514f, 0.726998f, 0.518056f, 0.249475f, 0.492155f, 0.643678f, 0.531052f, 0.481617f, 0.370308f, 0.743741f, 0.562172f, 0.233361f, 0.498431f, 0.679567f, 0.580747f, 0.494199f, 0.481097f, 0.817782f, 0.461707f, 0.369188f, 0.573825f, 0.419752f, 0.161614f, 0.386708f, 0.472911f, 0.822003f, 0.473412f, 0.375830f, 0.569966f, 0.422158f, 0.149228f, 0.380008f, 0.454662f, 0.818956f, 0.465984f, 0.370169f, 0.575537f, 0.423344f, 0.153818f, 0.375466f, 0.572526f, 0.348075f, 0.380718f, 0.641409f, 0.417012f, 0.407621f, 0.389074f, 0.834251f, 0.581008f, 0.348183f, 0.383659f, 0.608061f, 0.435032f, 0.422240f, 0.393710f, 0.832528f, 0.600530f, 0.360439f, 0.371006f, 0.609018f, 0.441082f, 0.416286f, 0.374920f, 0.825853f, 0.364932f, 0.727047f, 0.540001f, 0.246375f, 0.501524f, 0.656266f, 0.541761f, 0.482865f, 0.360322f, 0.752650f, 0.542120f, 0.239561f, 0.491207f, 0.663446f, 0.566643f, 0.491988f, 0.364532f, 0.737402f, 0.546869f, 0.240953f, 0.497072f, 0.664793f, 0.558528f, 0.488182f, 0.490592f, 0.819727f, 0.468739f, 0.379671f, 0.572959f, 0.422399f, 0.152699f, 0.387445f, 0.462308f, 0.822644f, 0.463886f, 0.374320f, 0.569615f, 0.423238f, 0.152603f, 0.387850f, 0.451896f, 0.818576f, 0.449904f, 0.362889f, 0.573917f, 0.421849f, 0.165145f, 0.390440f, 0.565044f, 0.343397f, 0.395512f, 0.584043f, 0.431062f, 0.417783f, 0.421165f, 0.830938f, 0.583998f, 0.354061f, 0.374016f, 0.633981f, 0.424457f, 0.404069f, 0.381920f, 0.829920f, 0.568315f, 0.347357f, 0.386911f, 0.624227f, 0.418162f, 0.411256f, 0.400332f, 0.832994f, 0.370475f, 0.739716f, 0.551429f, 0.234114f, 0.499500f, 0.665245f, 0.570648f, 0.485298f, 0.364035f, 0.756092f, 0.542251f, 0.238706f, 0.495463f, 0.659518f, 0.567976f, 0.489204f, 0.368942f, 0.756397f, 0.548083f, 0.231854f, 0.496617f, 0.659726f, 0.578330f, 0.484921f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, 
TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + TEST(AttentionTest, Attention4DGqaAttnMask) { int batch_size = 2; // Q.shape[0] int q_num_heads = 9; // Q.shape[1] @@ -847,7 +880,7 @@ TEST(AttentionTest, Attention4DGqaAttnMask) { // {4, 6} std::vector m = {0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f}; // {2, 9, 4, 8} - std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.508992f, 0.253478f, 0.553979f, 0.466355f, 0.398637f, 0.412493f, 0.495810f, 0.677675f, 0.521609f, 0.278997f, 0.564189f, 0.434417f, 0.448085f, 0.467205f, 0.567856f, 0.664713f, 0.490146f, 0.261321f, 0.560582f, 0.424598f, 0.450318f, 0.467336f, 0.520983f, 0.720798f, 0.516095f, 0.264495f, 0.577940f, 0.475340f, 0.444145f, 0.477909f, 0.485663f, 0.672846f, 0.499389f, 0.402198f, 0.520218f, 0.550550f, 0.481065f, 0.730488f, 0.492535f, 0.392315f, 0.436722f, 0.398514f, 0.497457f, 0.502270f, 0.520993f, 0.730472f, 0.565429f, 0.380282f, 0.461226f, 0.392968f, 0.536035f, 0.505191f, 0.446570f, 0.751253f, 0.478584f, 0.389036f, 0.423738f, 0.443828f, 0.554323f, 0.462607f, 0.476656f, 0.733228f, 0.482219f, 0.411910f, 0.620556f, 0.662948f, 0.349409f, 0.482541f, 0.537250f, 0.351544f, 0.734285f, 0.397172f, 0.689500f, 0.637077f, 0.320710f, 0.470914f, 0.526307f, 0.312878f, 0.775762f, 0.384457f, 0.696615f, 0.681034f, 0.324383f, 0.459632f, 0.539497f, 0.317950f, 0.709736f, 0.320698f, 0.671696f, 0.676830f, 0.332387f, 0.453234f, 0.578648f, 0.345084f, 0.685369f, 0.328092f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.500042f, 0.410507f, 0.521381f, 0.553244f, 0.459062f, 0.719706f, 0.476571f, 0.395052f, 0.429926f, 0.408857f, 0.507006f, 0.493937f, 0.529878f, 0.728873f, 0.571495f, 0.376256f, 0.453676f, 0.380482f, 0.526100f, 0.496696f, 0.457383f, 0.761933f, 0.486657f, 0.396608f, 0.435748f, 0.432822f, 0.531763f, 0.482255f, 0.477046f, 0.726381f, 0.487480f, 0.416572f, 0.626676f, 0.683736f, 0.340657f, 0.475002f, 0.549981f, 0.353311f, 0.740157f, 0.378827f, 0.681403f, 0.636622f, 0.324593f, 0.469088f, 0.537323f, 0.321344f, 0.762506f, 0.384239f, 0.693108f, 0.683351f, 0.329873f, 0.460504f, 0.555115f, 0.325379f, 0.694659f, 0.316422f, 0.677285f, 0.670298f, 0.329724f, 0.456327f, 0.567533f, 0.337560f, 0.701396f, 0.336191f, 0.515940f, 0.251020f, 0.562035f, 0.442479f, 0.405802f, 0.410828f, 0.519841f, 0.686781f, 0.522057f, 0.285013f, 0.562761f, 0.453472f, 0.451971f, 0.481286f, 0.558322f, 0.649971f, 0.486787f, 0.258011f, 0.557963f, 0.426743f, 0.442028f, 0.457034f, 0.510534f, 0.724945f, 0.498901f, 0.272090f, 0.572650f, 0.467930f, 0.465335f, 0.506181f, 0.484559f, 0.690090f, 
0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.379088f, 0.582413f, 0.414383f, 0.571800f, 0.613176f, 0.687631f, 0.185596f, 0.656867f, 0.390452f, 0.532452f, 0.407547f, 0.564799f, 0.606499f, 0.653258f, 0.176547f, 0.698038f, 0.410398f, 0.604586f, 0.442972f, 0.497533f, 0.595085f, 0.732265f, 0.187201f, 0.663169f, 0.448716f, 0.590302f, 0.411879f, 0.518449f, 0.636722f, 0.695827f, 0.154292f, 0.666828f, 0.458054f, 0.608582f, 0.430376f, 0.316371f, 0.547620f, 0.542559f, 0.542043f, 0.556297f, 0.468371f, 0.559154f, 0.465195f, 0.344099f, 0.482571f, 0.527115f, 0.527529f, 0.616254f, 0.494566f, 0.605555f, 0.432360f, 0.382197f, 0.466678f, 0.556031f, 0.459313f, 0.588575f, 0.532798f, 0.597684f, 0.412305f, 0.393400f, 0.462773f, 0.491821f, 0.483189f, 0.593919f, 0.569241f, 0.793791f, 0.532988f, 0.300026f, 0.393843f, 0.327085f, 0.448199f, 0.457416f, 0.493302f, 0.725336f, 0.512066f, 0.327500f, 0.404238f, 0.351704f, 0.507818f, 0.477990f, 0.479548f, 0.756083f, 0.511730f, 0.309729f, 0.366024f, 0.338031f, 0.503335f, 0.472352f, 0.473026f, 0.696816f, 0.543129f, 0.374608f, 0.335432f, 0.360978f, 0.486364f, 0.531799f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.453034f, 0.596627f, 0.417365f, 0.314318f, 0.554269f, 0.518967f, 0.550250f, 0.556252f, 0.494918f, 0.587774f, 0.467566f, 0.350222f, 0.481994f, 0.538857f, 0.525631f, 0.605359f, 0.497486f, 0.608472f, 0.429145f, 0.384532f, 0.466790f, 0.554752f, 0.457698f, 0.586510f, 0.548577f, 0.604359f, 0.398097f, 0.414429f, 0.448200f, 0.485158f, 0.461395f, 0.593015f, 0.563470f, 0.796184f, 0.532783f, 0.293209f, 0.408910f, 0.327450f, 0.438028f, 0.447011f, 0.493041f, 0.739603f, 0.496957f, 0.311881f, 0.389768f, 0.352503f, 0.530113f, 0.476738f, 0.484897f, 0.752985f, 0.511921f, 0.312174f, 0.370408f, 0.339775f, 0.504061f, 0.473793f, 0.487978f, 0.714687f, 0.538817f, 0.358426f, 0.348908f, 0.355820f, 0.481380f, 0.516214f, 0.370872f, 0.602034f, 0.400225f, 0.611090f, 0.630508f, 0.662527f, 0.162489f, 0.658299f, 0.378734f, 0.537283f, 0.412214f, 0.570032f, 0.601452f, 0.653569f, 0.179932f, 0.693105f, 0.411981f, 0.605715f, 0.448022f, 0.481469f, 0.585099f, 0.748463f, 0.195177f, 0.671915f, 0.442141f, 0.581881f, 0.393362f, 0.555388f, 0.650764f, 0.665937f, 0.141141f, 0.675100f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f}; + std::vector y = {0.641842f, 
0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.644472f, 0.666279f, 0.336558f, 0.478260f, 0.534820f, 0.338286f, 0.756443f, 0.387184f, 0.674255f, 0.645509f, 0.327427f, 0.465534f, 0.543598f, 0.328256f, 0.743604f, 0.373978f, 0.689753f, 0.687485f, 0.332246f, 0.457085f, 0.565540f, 0.331625f, 0.677863f, 0.308191f, 0.663033f, 0.669169f, 0.333832f, 0.452516f, 0.576569f, 0.348823f, 0.685447f, 0.338196f, 0.613061f, 0.681689f, 0.345384f, 0.474784f, 0.541609f, 0.357958f, 0.728217f, 0.383408f, 0.680108f, 0.637886f, 0.329455f, 0.469504f, 0.544973f, 0.325193f, 0.745572f, 0.378169f, 0.695405f, 0.687321f, 0.323229f, 0.456101f, 0.553544f, 0.323743f, 0.706057f, 0.314785f, 0.672814f, 0.678842f, 0.323628f, 0.449345f, 0.572724f, 0.342071f, 0.707722f, 0.332714f, 0.512254f, 0.252087f, 0.555774f, 0.456582f, 0.393340f, 0.400567f, 0.501655f, 0.680466f, 0.530775f, 0.288611f, 0.570275f, 0.444357f, 0.454871f, 0.480588f, 0.567893f, 0.645871f, 0.491847f, 0.262209f, 0.561930f, 0.418081f, 0.444398f, 0.456345f, 0.519658f, 0.722565f, 0.523232f, 0.267034f, 0.591659f, 0.459565f, 0.462164f, 0.494775f, 0.497558f, 0.678628f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.513034f, 0.252153f, 0.561841f, 0.455825f, 0.411518f, 0.424734f, 0.508095f, 0.683202f, 0.537475f, 0.278680f, 0.572605f, 0.449901f, 0.433722f, 0.452424f, 0.554372f, 0.643199f, 0.503808f, 0.259719f, 0.571011f, 0.415224f, 0.442363f, 0.450636f, 0.525191f, 0.716156f, 0.524579f, 0.263175f, 0.588806f, 0.462952f, 0.450874f, 0.480435f, 0.495070f, 0.675950f, 0.503113f, 0.409947f, 0.538941f, 0.550010f, 0.457564f, 0.729741f, 0.472483f, 0.384586f, 0.421666f, 0.416784f, 0.522405f, 0.484472f, 0.519795f, 0.728113f, 0.570887f, 0.363251f, 0.462182f, 0.372738f, 0.510951f, 0.511798f, 0.446353f, 0.754695f, 0.485592f, 0.397135f, 0.421437f, 0.447040f, 0.546262f, 0.462919f, 0.473860f, 0.726421f, 0.479062f, 0.420641f, 0.498228f, 0.402912f, 0.524895f, 0.548811f, 0.462668f, 0.729601f, 0.480759f, 0.390396f, 0.421638f, 0.418506f, 0.518644f, 0.484993f, 0.512452f, 0.724489f, 0.562537f, 0.370564f, 0.461864f, 0.376424f, 0.511195f, 0.510163f, 0.461531f, 0.755198f, 0.491549f, 0.400847f, 0.425338f, 0.456035f, 0.553542f, 0.466468f, 0.482400f, 0.722062f, 0.483532f, 0.415135f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.560779f, 0.795626f, 0.527843f, 0.292198f, 
0.403399f, 0.328103f, 0.449548f, 0.449270f, 0.492632f, 0.741337f, 0.501964f, 0.308729f, 0.404425f, 0.353946f, 0.510715f, 0.469292f, 0.498506f, 0.749246f, 0.510938f, 0.317603f, 0.377607f, 0.333171f, 0.516589f, 0.472113f, 0.494030f, 0.738331f, 0.525273f, 0.334388f, 0.351797f, 0.349013f, 0.492978f, 0.499192f, 0.558701f, 0.785575f, 0.541472f, 0.309741f, 0.379566f, 0.336180f, 0.433460f, 0.471779f, 0.500494f, 0.748997f, 0.495158f, 0.302537f, 0.401868f, 0.348977f, 0.525071f, 0.465493f, 0.496427f, 0.763380f, 0.504640f, 0.303037f, 0.375539f, 0.332025f, 0.517142f, 0.464096f, 0.466789f, 0.731320f, 0.529262f, 0.338950f, 0.329005f, 0.361720f, 0.481664f, 0.514476f, 0.356477f, 0.623874f, 0.420893f, 0.592125f, 0.610336f, 0.687956f, 0.174269f, 0.652548f, 0.366057f, 0.567382f, 0.428770f, 0.553226f, 0.582617f, 0.683498f, 0.188604f, 0.695704f, 0.406930f, 0.625170f, 0.441775f, 0.499327f, 0.590722f, 0.740689f, 0.180721f, 0.681143f, 0.430954f, 0.584531f, 0.412720f, 0.532459f, 0.630830f, 0.690216f, 0.161882f, 0.663851f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.365268f, 0.611770f, 0.413907f, 0.600775f, 0.622849f, 0.667798f, 0.164152f, 0.647839f, 0.377540f, 0.543255f, 0.401769f, 0.588162f, 0.610896f, 0.645976f, 0.172500f, 0.695675f, 0.428349f, 0.590245f, 0.429343f, 0.497694f, 0.606978f, 0.727059f, 0.182826f, 0.671502f, 0.466759f, 0.580932f, 0.396764f, 0.527984f, 0.655065f, 0.677027f, 0.138356f, 0.672848f, 0.431113f, 0.593599f, 0.391529f, 0.327778f, 0.551802f, 0.526872f, 0.512055f, 0.547473f, 0.461591f, 0.564565f, 0.469932f, 0.335454f, 0.493299f, 0.536959f, 0.537769f, 0.611109f, 0.505296f, 0.606927f, 0.414343f, 0.395585f, 0.462205f, 0.538029f, 0.450814f, 0.585742f, 0.550355f, 0.606479f, 0.419783f, 0.396625f, 0.449703f, 0.500831f, 0.464506f, 0.594653f, 0.460993f, 0.609826f, 0.424563f, 0.322395f, 0.546231f, 0.537700f, 0.541169f, 0.555672f, 0.479953f, 0.573210f, 0.449011f, 0.356276f, 0.482535f, 0.523785f, 0.516393f, 0.605958f, 0.473948f, 0.587667f, 0.412118f, 0.378344f, 0.472903f, 0.540161f, 0.445341f, 0.585184f, 0.561693f, 0.609513f, 0.394200f, 0.418769f, 0.444939f, 0.478136f, 0.458334f, 0.591187f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f}; ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); @@ -886,7 +919,7 @@ TEST(AttentionTest, Attention4DGqaWithPastAndPresent) { // {2, 3, 12, 8} std::vector past_value = {0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 
0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 
0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.481102f, 0.251523f, 0.876682f, 0.324273f, 0.924623f, 0.974787f, 0.449862f, 0.227129f, 0.291666f, 0.776334f, 0.273350f, 0.380583f, 0.478576f, 0.575111f, 0.996100f, 0.232210f, 0.353424f, 0.262891f, 0.361113f, 0.100805f, 0.359810f, 0.887865f, 0.298590f, 0.371935f}; // {2, 9, 4, 8} - std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.440810f, 0.437705f, 0.476508f, 0.320820f, 0.605191f, 0.640150f, 0.306216f, 0.610947f, 0.485794f, 0.448216f, 0.485639f, 0.323744f, 0.594446f, 0.646597f, 0.321742f, 0.605751f, 0.501858f, 0.445502f, 0.487899f, 0.384660f, 0.597134f, 0.616430f, 0.331401f, 0.566459f, 0.502522f, 0.409965f, 0.526639f, 0.348601f, 0.565200f, 0.586558f, 0.325044f, 0.603422f, 0.450250f, 0.368009f, 0.550911f, 0.460338f, 0.523907f, 0.508816f, 0.575624f, 0.426601f, 0.472310f, 0.372844f, 0.517852f, 0.431688f, 0.551555f, 0.527657f, 0.600578f, 0.473069f, 0.456633f, 0.442035f, 0.539875f, 0.437863f, 0.540202f, 0.499608f, 0.556470f, 0.419831f, 0.463081f, 0.416724f, 0.526389f, 0.458654f, 0.540120f, 0.551554f, 0.569399f, 0.447102f, 0.534296f, 0.597655f, 0.509699f, 0.487167f, 0.607438f, 0.426383f, 0.522794f, 0.458435f, 0.510147f, 0.622761f, 0.501724f, 0.453386f, 0.629671f, 0.434103f, 0.582477f, 
0.437681f, 0.520031f, 0.568543f, 0.525216f, 0.490370f, 0.571745f, 0.428629f, 0.572995f, 0.460086f, 0.533607f, 0.614962f, 0.474130f, 0.456345f, 0.576467f, 0.448127f, 0.599211f, 0.432252f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.452598f, 0.361594f, 0.550919f, 0.455099f, 0.530404f, 0.519313f, 0.588655f, 0.431890f, 0.464325f, 0.389636f, 0.515359f, 0.429087f, 0.540767f, 0.518376f, 0.586627f, 0.471074f, 0.458527f, 0.422216f, 0.537762f, 0.434123f, 0.550956f, 0.507704f, 0.564828f, 0.421548f, 0.463044f, 0.407985f, 0.523093f, 0.473684f, 0.542663f, 0.551348f, 0.576783f, 0.448743f, 0.546208f, 0.621128f, 0.501647f, 0.468191f, 0.612298f, 0.425183f, 0.549241f, 0.447622f, 0.519355f, 0.619636f, 0.487775f, 0.444259f, 0.625749f, 0.430264f, 0.584338f, 0.436887f, 0.521021f, 0.572716f, 0.522539f, 0.486440f, 0.581317f, 0.429079f, 0.579691f, 0.455426f, 0.526431f, 0.604615f, 0.476481f, 0.469814f, 0.588766f, 0.445640f, 0.609160f, 0.437785f, 0.443498f, 0.439338f, 0.487424f, 0.310942f, 0.607341f, 0.630362f, 0.312591f, 0.621999f, 0.483917f, 0.446308f, 0.477454f, 0.331028f, 0.592608f, 0.653297f, 0.322368f, 0.599377f, 0.497354f, 0.443447f, 0.477781f, 0.384002f, 0.591587f, 0.610287f, 0.328537f, 0.567630f, 0.499369f, 0.421961f, 0.536492f, 0.345379f, 0.586450f, 0.600541f, 0.312965f, 0.609437f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.557802f, 0.585925f, 0.426858f, 0.464044f, 0.585251f, 0.557395f, 0.433327f, 0.615342f, 0.534368f, 0.573723f, 0.426393f, 0.518102f, 0.586735f, 0.513129f, 0.371969f, 0.636735f, 0.544166f, 0.588469f, 0.433470f, 0.481894f, 0.595019f, 0.533156f, 0.396519f, 0.608115f, 0.547125f, 0.604473f, 0.441984f, 0.469765f, 0.599107f, 0.561685f, 0.347618f, 0.563457f, 0.507550f, 0.485293f, 0.545846f, 0.408434f, 0.482538f, 0.532314f, 0.498883f, 0.525126f, 0.514603f, 0.471457f, 0.539705f, 0.362410f, 0.490158f, 0.513690f, 0.494170f, 0.496909f, 0.492936f, 0.506153f, 0.565865f, 0.364727f, 0.508899f, 0.516217f, 0.558362f, 0.556920f, 0.530472f, 0.521715f, 0.554673f, 0.363830f, 0.509086f, 0.511590f, 0.552396f, 0.541486f, 0.572145f, 0.551531f, 0.471964f, 0.485188f, 0.555030f, 0.493247f, 0.376875f, 0.429387f, 0.580540f, 0.550944f, 0.435664f, 0.480675f, 0.544997f, 0.488698f, 0.344985f, 0.464878f, 0.593774f, 0.541202f, 0.484834f, 0.497316f, 0.509364f, 0.500045f, 0.357235f, 0.448933f, 0.565242f, 0.546653f, 0.459790f, 0.481954f, 0.514950f, 0.516297f, 0.344285f, 0.454476f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 
0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.507041f, 0.473640f, 0.547768f, 0.413960f, 0.490513f, 0.534377f, 0.497277f, 0.517772f, 0.531394f, 0.489105f, 0.531671f, 0.369343f, 0.486462f, 0.501787f, 0.494220f, 0.493498f, 0.485968f, 0.510301f, 0.559766f, 0.361474f, 0.507888f, 0.518858f, 0.564300f, 0.561990f, 0.537984f, 0.527982f, 0.539571f, 0.366920f, 0.498313f, 0.505709f, 0.538027f, 0.541246f, 0.585733f, 0.565800f, 0.441346f, 0.476255f, 0.556453f, 0.497693f, 0.363246f, 0.426799f, 0.578484f, 0.556489f, 0.436699f, 0.481177f, 0.549473f, 0.484153f, 0.355910f, 0.462010f, 0.590951f, 0.542803f, 0.470954f, 0.488994f, 0.512707f, 0.511876f, 0.358555f, 0.455953f, 0.559449f, 0.546003f, 0.462900f, 0.471080f, 0.517298f, 0.519225f, 0.345016f, 0.449149f, 0.526624f, 0.606761f, 0.427660f, 0.480775f, 0.577420f, 0.538850f, 0.426959f, 0.625509f, 0.530502f, 0.585784f, 0.432234f, 0.516800f, 0.584937f, 0.514154f, 0.373726f, 0.623740f, 0.550470f, 0.585577f, 0.436483f, 0.474799f, 0.594100f, 0.540052f, 0.402520f, 0.607686f, 0.537556f, 0.609680f, 0.439490f, 0.477886f, 0.602656f, 0.542957f, 0.350394f, 0.574553f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f}; + std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.546090f, 0.618047f, 0.504325f, 0.472246f, 0.609686f, 0.422467f, 0.546964f, 0.451166f, 0.519404f, 0.617868f, 0.491984f, 0.445771f, 0.633094f, 0.436822f, 0.559753f, 0.447209f, 0.519860f, 0.574899f, 0.525759f, 0.489339f, 0.586803f, 0.436452f, 0.577737f, 0.453299f, 0.532473f, 0.609446f, 0.471758f, 0.455772f, 0.573504f, 0.445466f, 0.602573f, 0.433307f, 0.538062f, 0.604199f, 0.500302f, 0.479569f, 0.614174f, 0.429231f, 0.522434f, 0.459369f, 0.528422f, 0.620683f, 0.485333f, 0.435606f, 0.616579f, 0.432233f, 0.565856f, 0.440093f, 0.525356f, 0.580613f, 0.529584f, 0.483095f, 0.583395f, 0.433491f, 0.593043f, 0.451879f, 0.540119f, 0.622995f, 0.472122f, 0.449888f, 0.586202f, 0.447435f, 0.611846f, 0.434879f, 0.449905f, 0.430732f, 0.474834f, 0.321674f, 0.590495f, 0.626300f, 0.319127f, 0.606006f, 0.492763f, 0.445330f, 0.490219f, 0.319940f, 0.588298f, 0.643644f, 0.317760f, 0.596360f, 0.507993f, 0.440004f, 0.490555f, 0.378128f, 0.588227f, 0.604974f, 0.329202f, 0.561987f, 0.511572f, 0.403440f, 0.542761f, 0.331792f, 0.568397f, 0.583366f, 0.333122f, 0.608456f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.441356f, 0.431701f, 0.488343f, 0.311828f, 0.606159f, 0.632821f, 0.317863f, 0.629084f, 0.495613f, 0.441177f, 0.473223f, 0.335484f, 0.579139f, 0.646878f, 0.321269f, 0.595437f, 0.504999f, 0.443626f, 0.498154f, 
0.369326f, 0.588410f, 0.600189f, 0.322347f, 0.562676f, 0.508419f, 0.405342f, 0.533092f, 0.335876f, 0.570568f, 0.589600f, 0.330741f, 0.609168f, 0.456943f, 0.365603f, 0.555030f, 0.454344f, 0.526263f, 0.519062f, 0.578652f, 0.425453f, 0.464039f, 0.391848f, 0.518985f, 0.419419f, 0.541410f, 0.514459f, 0.586459f, 0.470210f, 0.460338f, 0.408599f, 0.539512f, 0.446249f, 0.551945f, 0.511356f, 0.575513f, 0.424325f, 0.452212f, 0.418205f, 0.525148f, 0.459799f, 0.536327f, 0.541881f, 0.571451f, 0.452969f, 0.454154f, 0.354641f, 0.553889f, 0.451027f, 0.536270f, 0.521832f, 0.590756f, 0.429859f, 0.459101f, 0.394962f, 0.512076f, 0.419296f, 0.535702f, 0.516757f, 0.585606f, 0.478117f, 0.458365f, 0.422929f, 0.531943f, 0.447581f, 0.546387f, 0.511705f, 0.564350f, 0.425332f, 0.463274f, 0.429223f, 0.525922f, 0.452328f, 0.539095f, 0.534372f, 0.563738f, 0.449120f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.586343f, 0.566462f, 0.444339f, 0.481474f, 0.557556f, 0.495837f, 0.368487f, 0.425850f, 0.580159f, 0.565990f, 0.400882f, 0.462578f, 0.551037f, 0.497924f, 0.338502f, 0.468483f, 0.592753f, 0.536897f, 0.481975f, 0.489485f, 0.519290f, 0.509298f, 0.366838f, 0.461538f, 0.567139f, 0.559419f, 0.458050f, 0.468739f, 0.514875f, 0.512271f, 0.346335f, 0.449357f, 0.583058f, 0.557532f, 0.454426f, 0.492673f, 0.551748f, 0.496414f, 0.364023f, 0.430048f, 0.579431f, 0.565100f, 0.420761f, 0.466297f, 0.551315f, 0.487418f, 0.348148f, 0.461136f, 0.585687f, 0.535194f, 0.485465f, 0.488622f, 0.513327f, 0.508844f, 0.368049f, 0.455823f, 0.554855f, 0.560589f, 0.456398f, 0.477641f, 0.507017f, 0.518069f, 0.338229f, 0.444624f, 0.500594f, 0.616610f, 0.439949f, 0.495561f, 0.569213f, 0.540425f, 0.422667f, 0.627919f, 0.514283f, 0.584446f, 0.441141f, 0.528331f, 0.577047f, 0.508969f, 0.372295f, 0.646734f, 0.536256f, 0.591823f, 0.428652f, 0.485852f, 0.592863f, 0.525360f, 0.399985f, 0.623408f, 0.552463f, 0.606841f, 0.448560f, 0.466321f, 0.600628f, 0.566464f, 0.356481f, 0.551351f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.532692f, 0.601573f, 0.425963f, 0.477495f, 0.573122f, 0.544325f, 0.422438f, 0.629794f, 0.512145f, 0.593241f, 0.436187f, 0.532146f, 0.582008f, 0.499410f, 0.366728f, 0.631277f, 0.550263f, 0.590346f, 0.430967f, 0.477189f, 0.600022f, 0.528313f, 0.406504f, 0.603355f, 0.537075f, 0.605495f, 0.437735f, 0.474413f, 0.601068f, 0.542204f, 0.348555f, 0.581430f, 0.499619f, 0.480920f, 0.536032f, 0.413380f, 0.478027f, 0.524393f, 0.490201f, 0.530954f, 0.517442f, 0.475326f, 0.541763f, 0.366450f, 0.498398f, 0.509411f, 0.503732f, 0.490468f, 0.488084f, 0.505941f, 0.554614f, 0.371690f, 0.503635f, 0.510325f, 
0.557424f, 0.564303f, 0.534730f, 0.536543f, 0.563296f, 0.362277f, 0.498957f, 0.508357f, 0.538003f, 0.554638f, 0.514150f, 0.481676f, 0.543535f, 0.414778f, 0.478296f, 0.529467f, 0.496600f, 0.522262f, 0.522734f, 0.480361f, 0.534209f, 0.379264f, 0.485836f, 0.500082f, 0.498644f, 0.501901f, 0.474729f, 0.503193f, 0.560206f, 0.362595f, 0.515144f, 0.512647f, 0.557224f, 0.567242f, 0.539217f, 0.533273f, 0.538641f, 0.373064f, 0.495733f, 0.499786f, 0.532998f, 0.547731f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f}; // {2, 3, 18, 8} std::vector present_key = {0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 
0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 
0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; // {2, 3, 18, 8} @@ -1116,7 +1149,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, 
-0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; // {2, 3, 6, 4} std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 1, 4, 13} + // {2, 1, 4, 18} std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, -0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, 
-0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f}; // {2, 3, 12, 4} std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; @@ -1132,7 +1165,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { // {2, 3, 4, 4} std::vector y = {-0.393782f, -0.387694f, -0.381606f, -0.375519f, -0.397492f, -0.391304f, -0.385116f, -0.378928f, -0.397474f, -0.391207f, -0.384941f, -0.378674f, -0.394849f, -0.388519f, -0.382190f, -0.375860f, -0.226271f, -0.220186f, -0.214101f, -0.208016f, -0.230042f, -0.223857f, -0.217672f, -0.211488f, -0.230104f, -0.223841f, -0.217577f, -0.211314f, -0.227525f, -0.221197f, -0.214870f, -0.208543f, -0.058757f, -0.052674f, -0.046592f, -0.040510f, -0.062587f, -0.056406f, -0.050224f, -0.044042f, -0.062730f, -0.056470f, -0.050209f, -0.043949f, -0.060198f, -0.053873f, -0.047548f, -0.041223f, 0.108760f, 
0.114840f, 0.120919f, 0.126999f, 0.104873f, 0.111051f, 0.117229f, 0.123408f, 0.104648f, 0.110906f, 0.117163f, 0.123421f, 0.107131f, 0.113454f, 0.119777f, 0.126099f, 0.276279f, 0.282356f, 0.288433f, 0.294510f, 0.272337f, 0.278512f, 0.284687f, 0.290862f, 0.272031f, 0.278286f, 0.284540f, 0.290794f, 0.274463f, 0.280783f, 0.287104f, 0.293424f, 0.443800f, 0.449874f, 0.455949f, 0.462023f, 0.439807f, 0.445978f, 0.452150f, 0.458321f, 0.439418f, 0.445669f, 0.451921f, 0.458172f, 0.441797f, 0.448115f, 0.454433f, 0.460751f}; - // {2, 3, 13, 4} + // {2, 3, 12, 4} std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 
0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; // {2, 3, 18, 8} std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 
0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; @@ -1151,28 +1184,28 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { ); } -TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { - int batch_size = 2; // Q.shape[0] - int q_num_heads = 3; // Q.shape[1] - int q_sequence_length = 4; // Q.shape[2] - int head_size = 4; // Q.shape[3] - int kv_sequence_length = 6; // K.shape[2] and V.shape[2] - int kv_num_heads = 3; // K.shape[1] and V.shape[1] - int v_head_size = 4; // V.shape[3] - int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] +TEST(AttentionTest, TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] - // {2, 3, 4, 4} - std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; - // {2, 3, 6, 4} - std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, 
-0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 3, 6, 8} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 
0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; // {2, 3, 6, 4} - std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, 
-0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 3, 4, 13} - std::vector m = {-0.454545f, -0.451340f, -0.448135f, -0.444930f, -0.441725f, -0.438520f, -0.435315f, -0.432110f, -0.428904f, -0.425699f, -0.422494f, -0.419289f, -0.416084f, -0.412879f, -0.409674f, -0.406469f, -0.403263f, -0.400058f, -0.396853f, -0.393648f, -0.390443f, -0.387238f, -0.384033f, -0.380828f, -0.377622f, -0.374417f, -0.371212f, -0.368007f, -0.364802f, -0.361597f, -0.358392f, -0.355186f, -0.351981f, -0.348776f, -0.345571f, -0.342366f, -0.339161f, -0.335956f, -0.332751f, -0.329545f, -0.326340f, -0.323135f, -0.319930f, -0.316725f, -0.313520f, -0.310315f, -0.307110f, -0.303904f, -0.300699f, -0.297494f, -0.294289f, -0.291084f, -0.287879f, -0.284674f, -0.281469f, -0.278263f, -0.275058f, -0.271853f, -0.268648f, -0.265443f, -0.262238f, -0.259033f, -0.255828f, -0.252622f, -0.249417f, -0.246212f, -0.243007f, -0.239802f, -0.236597f, -0.233392f, -0.230186f, -0.226981f, -0.223776f, -0.220571f, -0.217366f, -0.214161f, -0.210956f, -0.207751f, -0.204545f, -0.201340f, -0.198135f, -0.194930f, -0.191725f, -0.188520f, -0.185315f, -0.182110f, -0.178904f, -0.175699f, -0.172494f, -0.169289f, -0.166084f, -0.162879f, -0.159674f, -0.156469f, -0.153263f, -0.150058f, -0.146853f, -0.143648f, -0.140443f, -0.137238f, -0.134033f, -0.130828f, -0.127622f, -0.124417f, -0.121212f, -0.118007f, -0.114802f, -0.111597f, -0.108392f, -0.105186f, -0.101981f, -0.098776f, -0.095571f, -0.092366f, -0.089161f, -0.085956f, -0.082751f, -0.079545f, -0.076340f, -0.073135f, -0.069930f, -0.066725f, -0.063520f, -0.060315f, -0.057110f, -0.053904f, -0.050699f, -0.047494f, -0.044289f, -0.041084f, -0.037879f, -0.034674f, -0.031469f, -0.028263f, -0.025058f, -0.021853f, -0.018648f, -0.015443f, -0.012238f, -0.009033f, -0.005828f, -0.002622f, 0.000583f, 0.003788f, 0.006993f, 0.010198f, 0.013403f, 0.016608f, 0.019814f, 0.023019f, 0.026224f, 0.029429f, 0.032634f, 0.035839f, 0.039044f, 0.042249f, 0.045455f, 0.048660f, 0.051865f, 0.055070f, 0.058275f, 0.061480f, 0.064685f, 0.067890f, 0.071096f, 0.074301f, 0.077506f, 0.080711f, 0.083916f, 0.087121f, 0.090326f, 0.093531f, 0.096737f, 0.099942f, 0.103147f, 0.106352f, 0.109557f, 0.112762f, 0.115967f, 0.119172f, 0.122378f, 0.125583f, 0.128788f, 0.131993f, 0.135198f, 0.138403f, 0.141608f, 0.144814f, 0.148019f, 0.151224f, 0.154429f, 0.157634f, 0.160839f, 0.164044f, 0.167249f, 0.170455f, 0.173660f, 0.176865f, 0.180070f, 0.183275f, 0.186480f, 0.189685f, 0.192890f, 0.196096f, 0.199301f, 0.202506f, 0.205711f, 0.208916f, 0.212121f, 0.215326f, 0.218531f, 0.221737f, 0.224942f, 0.228147f, 0.231352f, 0.234557f, 0.237762f, 0.240967f, 0.244172f, 0.247378f, 0.250583f, 0.253788f, 0.256993f, 0.260198f, 
0.263403f, 0.266608f, 0.269814f, 0.273019f, 0.276224f, 0.279429f, 0.282634f, 0.285839f, 0.289044f, 0.292249f, 0.295455f, 0.298660f, 0.301865f, 0.305070f, 0.308275f, 0.311480f, 0.314685f, 0.317890f, 0.321096f, 0.324301f, 0.327506f, 0.330711f, 0.333916f, 0.337121f, 0.340326f, 0.343531f, 0.346737f, 0.349942f, 0.353147f, 0.356352f, 0.359557f, 0.362762f, 0.365967f, 0.369172f, 0.372378f, 0.375583f, 0.378788f, 0.381993f, 0.385198f, 0.388403f, 0.391608f, 0.394814f, 0.398019f, 0.401224f, 0.404429f, 0.407634f, 0.410839f, 0.414044f, 0.417249f, 0.420455f, 0.423660f, 0.426865f, 0.430070f, 0.433275f, 0.436480f, 0.439685f, 0.442890f, 0.446096f, 0.449301f, 0.452506f, 0.455711f, 0.458916f, 0.462121f, 0.465326f, 0.468531f, 0.471737f, 0.474942f, 0.478147f, 0.481352f, 0.484557f, 0.487762f, 0.490967f, 0.494172f, 0.497378f, 0.500583f, 0.503788f, 0.506993f, 0.510198f, 0.513403f, 0.516608f, 0.519814f, 0.523019f, 0.526224f, 0.529429f, 0.532634f, 0.535839f, 0.539044f, 0.542249f}; - // {2, 3, 12, 4} - std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; - // {2, 3, 12, 4} - std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, 
-0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 
0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 
0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f}; + // {2, 3, 12, 8} + std::vector past_key = {0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 
0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 
0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f}; + // {2, 3, 12, 8} + std::vector past_value = {0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 
0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 
0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f}; ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); @@ -1181,14 +1214,15 @@ TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); - // {2, 3, 4, 4} - std::vector y = {-0.385742f, -0.379327f, -0.372911f, -0.366496f, -0.385554f, -0.379139f, -0.372723f, -0.366308f, -0.385366f, -0.378950f, -0.372535f, -0.366119f, -0.385178f, -0.378762f, -0.372347f, -0.365931f, -0.218323f, -0.211907f, -0.205492f, -0.199076f, -0.218134f, -0.211719f, -0.205304f, -0.198888f, -0.217946f, -0.211531f, -0.205115f, -0.198700f, -0.217758f, -0.211342f, -0.204927f, -0.198512f, -0.050903f, -0.044487f, -0.038072f, -0.031657f, -0.050715f, -0.044299f, -0.037884f, -0.031468f, -0.050526f, -0.044111f, -0.037695f, -0.031280f, -0.050338f, -0.043922f, -0.037507f, -0.031092f, 0.116517f, 0.122932f, 0.129348f, 0.135763f, 0.116705f, 0.123121f, 0.129536f, 0.135952f, 0.116894f, 0.123309f, 0.129724f, 0.136140f, 0.117082f, 0.123497f, 0.129913f, 0.136328f, 0.283937f, 0.290352f, 0.296768f, 0.303183f, 0.284125f, 0.290540f, 0.296956f, 0.303371f, 0.284313f, 0.290729f, 0.297144f, 0.303559f, 0.284501f, 0.290917f, 0.297332f, 0.303747f, 0.451356f, 0.457772f, 0.464187f, 0.470602f, 0.451544f, 0.457960f, 0.464375f, 0.470790f, 0.451732f, 0.458148f, 0.464563f, 0.470978f, 0.451920f, 0.458336f, 0.464751f, 0.471166f}; - // {2, 3, 13, 4} - std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, 
-0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 8} + std::vector y = {0.431265f, 0.558994f, 0.492979f, 0.535281f, 0.609591f, 0.466737f, 0.692090f, 0.412591f, 0.468058f, 0.623595f, 0.468127f, 0.483497f, 0.577278f, 0.512802f, 0.639767f, 0.427679f, 0.422704f, 0.532822f, 0.449594f, 0.560548f, 0.608427f, 0.476187f, 0.695694f, 0.425740f, 0.447270f, 0.528366f, 0.506840f, 0.501836f, 0.547248f, 0.457381f, 0.583533f, 0.471707f, 0.414727f, 0.517263f, 0.342732f, 0.363543f, 0.677046f, 0.664675f, 0.271455f, 0.479982f, 0.438313f, 0.537211f, 0.342649f, 0.402609f, 0.660072f, 0.631518f, 0.266481f, 0.501402f, 0.458457f, 0.519536f, 0.434125f, 0.443849f, 0.614893f, 0.636419f, 0.310940f, 0.497030f, 0.433312f, 0.522457f, 0.417441f, 0.405432f, 0.617509f, 0.592985f, 0.310558f, 0.490073f, 0.499459f, 0.430465f, 0.601451f, 0.404111f, 0.502848f, 0.415186f, 0.440655f, 0.478187f, 0.536562f, 0.376663f, 0.527310f, 0.363608f, 0.443744f, 0.476396f, 0.453812f, 0.498910f, 0.483497f, 0.433209f, 0.541590f, 0.366029f, 0.513807f, 0.477506f, 0.492110f, 0.527910f, 0.471458f, 0.419741f, 0.536529f, 0.407806f, 0.512188f, 0.467064f, 0.496260f, 0.519270f, 
0.683252f, 0.426643f, 0.425275f, 0.457410f, 0.611686f, 0.591234f, 0.394568f, 0.446171f, 0.637484f, 0.426481f, 0.346779f, 0.466867f, 0.585075f, 0.558250f, 0.387627f, 0.507636f, 0.658808f, 0.467355f, 0.496107f, 0.556756f, 0.513309f, 0.520842f, 0.411220f, 0.451704f, 0.661693f, 0.463543f, 0.421647f, 0.486068f, 0.552701f, 0.484705f, 0.412050f, 0.449818f, 0.637941f, 0.564086f, 0.543446f, 0.530844f, 0.627347f, 0.520370f, 0.389963f, 0.520054f, 0.574335f, 0.604007f, 0.468559f, 0.473710f, 0.559229f, 0.504183f, 0.453090f, 0.564618f, 0.568083f, 0.541180f, 0.491888f, 0.485970f, 0.564150f, 0.506989f, 0.421426f, 0.544228f, 0.616426f, 0.467555f, 0.529898f, 0.487372f, 0.574411f, 0.471969f, 0.388121f, 0.485012f, 0.533687f, 0.523210f, 0.560021f, 0.490233f, 0.443149f, 0.420163f, 0.538998f, 0.606965f, 0.586616f, 0.478324f, 0.572142f, 0.517933f, 0.441955f, 0.411890f, 0.550505f, 0.604577f, 0.541173f, 0.473423f, 0.505749f, 0.473388f, 0.389025f, 0.498730f, 0.507861f, 0.584389f, 0.519963f, 0.461030f, 0.576878f, 0.471281f, 0.461238f, 0.496673f, 0.509573f, 0.568405f}; // {2, 3, 18, 8} - std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 
0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 3, 4, 13} - std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 
0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + std::vector present_key = {0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 
0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 
0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 8} + std::vector<float>
present_value = {0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 
0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 
0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + constexpr float inff = std::numeric_limits<float>::infinity(); + std::vector<float> qk_matmul = {2.137658f, 1.567682f, 1.582827f, 0.953936f, 0.636597f, 1.001645f, 1.885707f, 1.361086f, 1.495408f, 1.566455f, 1.459078f, 1.668413f, 0.904174f, -inff, -inff, -inff, -inff, -inff, 1.229267f, 0.591855f, 1.372683f, 0.964445f, 1.006092f, 1.046331f, 1.712052f, 1.060710f, 2.141520f, 1.917742f, 1.063752f, 0.892409f, 0.884336f, 0.881352f, -inff, -inff, -inff, -inff, 2.235662f, 1.742821f, 2.198921f, 1.079357f, 1.510221f, 1.812315f, 1.396341f, 1.864746f, 1.498768f, 2.115730f, 0.844762f, 1.323617f, 1.096593f, 1.033003f, 1.868677f, -inff, -inff, -inff, 1.429269f, 0.876355f, 0.928405f, 1.469794f, 0.649940f, 1.435654f, 1.452830f, 1.053687f, 1.338220f, 0.966775f, 1.237266f, 1.488850f, 1.438267f, 0.931250f, 1.633272f, 0.944889f, -inff, -inff, 1.172613f, 1.105815f, 1.263303f, 1.702161f, 1.406517f, 1.808470f, 1.496128f, 1.169961f, 1.428707f, 1.393064f, 1.624670f, 1.287919f, 0.674733f, -inff, -inff, -inff, -inff, -inff, 0.838456f, 1.191558f, 1.771291f, 1.491907f, 0.911088f, 0.865799f, 1.154893f, 1.472593f, 0.826140f, 0.896018f, 
1.281853f, 0.942941f, 1.470656f, 0.816028f, -inff, -inff, -inff, -inff, 1.133820f, 1.086309f, 1.712385f, 1.254675f, 1.427773f, 0.748848f, 1.056134f, 1.187805f, 1.419181f, 1.140224f, 1.269629f, 1.135934f, 0.694738f, 1.528325f, 0.959286f, -inff, -inff, -inff, 1.160321f, 1.097000f, 1.485019f, 1.111147f, 0.836961f, 0.948765f, 1.234762f, 0.835082f, 0.833382f, 0.589928f, 1.266538f, 1.303439f, 0.622733f, 0.837537f, 0.605730f, 0.730216f, -inff, -inff, 2.078597f, 0.610472f, 1.371772f, 0.794857f, 1.018924f, 1.165257f, 1.466839f, 1.206415f, 1.662507f, 1.098436f, 1.283408f, 1.533854f, 1.247966f, -inff, -inff, -inff, -inff, -inff, 1.707491f, 0.439978f, 0.919238f, 0.297115f, 0.982817f, 1.370520f, 0.766707f, 0.938981f, 1.095468f, 1.442393f, 0.742909f, 0.529869f, 0.628822f, 1.353301f, -inff, -inff, -inff, -inff, 1.483284f, 1.334536f, 0.757364f, 1.243801f, 0.767143f, 0.919318f, 0.693929f, 1.000990f, 1.107699f, 1.001247f, 1.434079f, 1.522769f, 0.696104f, 1.336034f, 0.501240f, -inff, -inff, -inff, 1.535892f, 1.342303f, 0.701559f, 1.211220f, 1.510985f, 0.961962f, 1.471503f, 1.440467f, 1.835586f, 0.947043f, 1.254547f, 1.009386f, 0.842613f, 1.508191f, 1.233544f, 1.280385f, -inff, -inff, 1.552432f, 0.958768f, 1.676495f, 1.810273f, 1.019336f, 1.487615f, 0.695035f, 1.391893f, 1.060641f, 0.917107f, 1.115109f, 1.128137f, 0.986429f, -inff, -inff, -inff, -inff, -inff, 1.289288f, 1.303667f, 0.882238f, 1.948027f, 1.580638f, 0.863439f, 1.059965f, 2.095325f, 1.493638f, 0.654104f, 0.828719f, 1.673449f, 0.479778f, 1.149678f, -inff, -inff, -inff, -inff, 1.177682f, 1.225590f, 1.735621f, 2.114078f, 1.905758f, 1.835981f, 1.432170f, 1.444457f, 2.016032f, 0.762211f, 1.059737f, 1.378216f, 1.564930f, 1.950097f, 1.598798f, -inff, -inff, -inff, 0.820477f, 0.962096f, 1.188223f, 1.264395f, 1.676953f, 1.487113f, 0.962162f, 1.377522f, 1.370079f, 1.450785f, 1.131087f, 1.962317f, 0.764849f, 0.777860f, 1.194763f, 1.030136f, -inff, -inff, 1.096708f, 1.345589f, 1.404595f, 1.370459f, 1.263369f, 1.364863f, 0.489623f, 0.596189f, 1.079480f, 0.915348f, 0.770954f, 1.548047f, 1.519504f, -inff, -inff, -inff, -inff, -inff, 1.856943f, 0.790590f, 1.235241f, 2.061177f, 1.282346f, 1.896653f, 1.112410f, 1.622862f, 0.780625f, 1.990919f, 1.693934f, 1.466544f, 1.026297f, 1.323339f, -inff, -inff, -inff, -inff, 1.778816f, 1.746915f, 1.169870f, 1.847628f, 0.729303f, 2.421048f, 1.266061f, 1.481203f, 1.016384f, 2.038725f, 1.132054f, 1.669076f, 1.958931f, 1.654780f, 1.644111f, -inff, -inff, -inff, 0.856287f, 1.124803f, 1.216201f, 0.831110f, 0.761234f, 1.204141f, 0.994307f, 0.832859f, 1.294077f, 1.566637f, 1.102631f, 1.472731f, 1.569911f, 0.779225f, 1.536189f, 1.277889f, -inff, -inff, 0.944230f, 1.585174f, 1.001532f, 0.973579f, 1.652668f, 1.112330f, 1.052878f, 1.326390f, 1.526319f, 1.790060f, 1.219317f, 1.742865f, 0.871467f, -inff, -inff, -inff, -inff, -inff, 0.794245f, 1.084904f, 0.813691f, 1.037344f, 0.254175f, 1.071614f, 0.477497f, 0.773591f, 1.317670f, 1.382451f, 0.759806f, 1.228428f, 0.583565f, 1.274037f, -inff, -inff, -inff, -inff, 0.865060f, 0.697643f, 1.300273f, 1.064195f, 1.435744f, 1.516307f, 0.626589f, 1.255387f, 1.115037f, 1.202643f, 1.789729f, 1.328769f, 1.046150f, 1.149905f, 1.696396f, -inff, -inff, -inff, 1.421552f, 1.324626f, 1.029005f, 0.960238f, 1.215132f, 1.450928f, 1.351898f, 1.718175f, 1.502146f, 1.736591f, 1.019685f, 1.130950f, 1.097223f, 1.330517f, 1.675029f, 1.069868f, -inff, -inff}; ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length 
+ kv_sequence_length) * head_size); ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); @@ -1196,7 +1230,7 @@ TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, q, k, v, m, std::initializer_list<float>(), past_key, past_value, - -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + 1, 1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, false, true, true // disable_cpu, disable_cuda, disable_dml ); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc b/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc index ea6b3f148979f..b0ee078335308 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc @@ -7,14 +7,12 @@ #include "core/framework/utils.h" #include "core/providers/cpu/nn/conv_attributes.h" #include "core/providers/utils.h" +#include "test/providers/internal_testing/internal_testing_execution_provider.h" namespace onnxruntime { namespace internal_testing_ep { -// can't use 'utils::kInternalTestingExecutionProvider' in the macro so redefine here to a name without '::' -constexpr const char* internal_testing_ep = utils::kInternalTestingExecutionProvider; - -ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, internal_testing_ep, +ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kInternalTestingExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), Conv); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 390329c5cae7a..934916ea862e9 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -26,17 +26,15 @@ namespace internal_testing_ep { // NHWC Conv requires contrib ops #if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) -// the 'utils::' breaks the kernel registration macros -constexpr const char* internal_testing_ep = utils::kInternalTestingExecutionProvider; -class ONNX_OPERATOR_KERNEL_CLASS_NAME(internal_testing_ep, kMSInternalNHWCDomain, 11, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kInternalTestingExecutionProvider, kMSInternalNHWCDomain, 11, Conv); // register static kernels we have implementations for static std::unique_ptr<KernelRegistry> RegisterKernels() { auto kernel_registry = std::make_unique<KernelRegistry>(); ORT_THROW_IF_ERROR(kernel_registry->Register( - BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(internal_testing_ep, kMSInternalNHWCDomain, 11, Conv)>())); + BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kInternalTestingExecutionProvider, kMSInternalNHWCDomain, 11, Conv)>())); return kernel_registry; @@ -68,7 +66,7 @@ void RegisterDummyStaticKernel(KernelRegistry& registry, const Node& node) { builder.SetName(node.OpType()) .SetDomain(node.Domain()) .SinceVersion(node.SinceVersion()) - .Provider(internal_testing_ep); + .Provider(kInternalTestingExecutionProvider); ORT_THROW_IF_ERROR(registry.Register(builder, DummyCreateKernel)); } @@ -85,7 +83,7 @@ constexpr const char*
INTERNAL_TESTING_EP = "InternalTestingEP"; InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops, const std::unordered_set<std::string>& stop_ops, DataLayout preferred_layout) - : IExecutionProvider{utils::kInternalTestingExecutionProvider}, + : IExecutionProvider{kInternalTestingExecutionProvider}, ep_name_{INTERNAL_TESTING_EP}, ops_{ops}, stop_ops_{stop_ops}, @@ -221,7 +219,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& auto compile_capabilities = utils::CreateSupportedPartitions(graph_viewer, supported_compiled_nodes, stop_ops_, generate_metadef_name, ep_name_, - onnxruntime::utils::kInternalTestingExecutionProvider, + kInternalTestingExecutionProvider, /*QDQ NodeUnit map*/ nullptr, debug_output_); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 0caa0febc2796..8832265798798 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -9,6 +9,9 @@ namespace onnxruntime { namespace internal_testing_ep { +// Provider type of `InternalTestingExecutionProvider`, an EP used for internal testing. +constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider"; + class InternalTestingExecutionProvider : public IExecutionProvider { public: InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops, diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc index d58db5178032d..c085d1acd10c0 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc @@ -57,7 +57,7 @@ auto RunTest(const std::string& op, const ORTCHAR_T* model_path) { for (const auto& node : graph.Nodes()) { EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); ++num_partitions; @@ -116,7 +116,7 @@ TEST(InternalTestingEP, TestDependenciesCorrectlyHandled) { for (const auto& node : graph.Nodes()) { EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) << "Nodes with supported op types should have been replaced.
Node with type " << node.OpType() << " was not."; - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); ++num_partitions; @@ -227,7 +227,7 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin std::string unsupported_op_str; for (const Node& node : graph.Nodes()) { - if (node.GetExecutionProviderType() != utils::kInternalTestingExecutionProvider && + if (node.GetExecutionProviderType() != kInternalTestingExecutionProvider && ops.count(node.OpType()) == 0) { auto entry = unsupported_ops.find(node.OpType()); if (entry != unsupported_ops.end()) { @@ -288,12 +288,12 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; } else { - EXPECT_NE(node.GetExecutionProviderType(), utils::kInternalTestingExecutionProvider) + EXPECT_NE(node.GetExecutionProviderType(), kInternalTestingExecutionProvider) << "Node is downstream from a 'stop at' node and should not have been taken. Node:" << node.Name(); } - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); ++stats.num_compiled_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 275f29fdd9073..94e60739c3ccf 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -334,7 +334,7 @@ TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) { for (const auto& node : graph.Nodes()) { EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); EXPECT_THAT(node.OpType(), ::testing::StartsWith(expected_op_type_prefix)); @@ -353,7 +353,7 @@ static int CountAndValidateAssignedNodes(const Graph& current_graph, for (const auto& node : current_graph.Nodes()) { EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) << "Nodes with supported op types should have been replaced. 
Node with type " << node.OpType() << " was not."; - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { const NodeComputeInfo* compute_func = nullptr; EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 1c8cc6f78fe63..a2f1b9b56538b 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -2076,6 +2076,21 @@ TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { } catch (const std::exception& e) { EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); } + + const char* const htp_perf_mode_type[] = {"ep.dynamic.qnn_htp_performance_mode"}; + const char* const eps_type[] = {"extreme_power_saver"}; + const char* const shp_type[] = {"sustained_high_performance"}; + session.SetEpDynamicOptions(htp_perf_mode_type, shp_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(htp_perf_mode_type, eps_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(htp_perf_mode_type, shp_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); } // Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 1820664e1d604..b85030b46e94d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -689,15 +689,30 @@ def test_run_model_with_optional_sequence_input(self): def test_run_model(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - input_name = sess.get_inputs()[0].name - self.assertEqual(input_name, "X") - input_shape = sess.get_inputs()[0].shape - self.assertEqual(input_shape, [3, 2]) - output_name = sess.get_outputs()[0].name - self.assertEqual(output_name, "Y") - output_shape = sess.get_outputs()[0].shape - self.assertEqual(output_shape, [3, 2]) - res = sess.run([output_name], {input_name: x}) + + inputs = sess.get_inputs() + self.assertEqual(len(inputs), 1) + self.assertEqual(inputs[0].name, "X") + self.assertEqual(inputs[0].shape, [3, 2]) + + input_meminfos = sess.get_input_memory_infos() + self.assertEqual(len(input_meminfos), 1) + self.assertIsNotNone(input_meminfos[0]) + + input_epdevices = sess.get_input_epdevices() + # The entry may be None (null) but it should be present + self.assertEqual(len(input_epdevices), 1) + + outputs = sess.get_outputs() + self.assertEqual(len(outputs), 1) + self.assertEqual(outputs[0].name, "Y") + self.assertEqual(outputs[0].shape, [3, 2]) + + output_meminfos = sess.get_output_memory_infos() + self.assertEqual(len(output_meminfos), 1) + self.assertIsNotNone(output_meminfos[0]) + + res = sess.run([outputs[0].name], {inputs[0].name: x}) output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05,
atol=1e-08) @@ -1584,6 +1599,44 @@ def test_run_model_with_cuda_copy_stream(self): for _iteration in range(100000): session.run(output_names=["output"], input_feed={"shape": shape}) + def test_ort_device(self): + cpu_device = onnxrt.OrtDevice.make("cpu", 0) + self.assertEqual(cpu_device.device_id(), 0) + self.assertEqual(cpu_device.device_type(), 0) + self.assertEqual(cpu_device.device_vendor_id(), 0) + self.assertEqual(cpu_device.device_mem_type(), 0) + + def test_ort_memory_info(self): + cpu_memory_info = onnxrt.OrtMemoryInfo( + "Cpu", + onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, + 0, + onnxrt.OrtMemType.DEFAULT, + ) + self.assertEqual(cpu_memory_info.name, "Cpu") + self.assertEqual(cpu_memory_info.device_id, 0) + self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT) + self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR) + self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertEqual(cpu_memory_info.device_vendor_id, 0) + + def test_ort_memory_info_create_v2(self): + cpu_memory_info = onnxrt.OrtMemoryInfo.create_v2( + "Test", + onnxrt.OrtMemoryInfoDeviceType.CPU, + 0, # vendor_id + 0, # device_id + onnxrt.OrtDeviceMemoryType.DEFAULT, + 128, # alignment + onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, + ) + self.assertEqual(cpu_memory_info.name, "Test") + self.assertEqual(cpu_memory_info.device_id, 0) + self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT) + self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR) + self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertEqual(cpu_memory_info.device_vendor_id, 0) + def test_shared_allocator_using_create_and_register_allocator(self): # Create and register an arena based allocator diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index cb31627a87c48..d66951bd66f3d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -226,6 +226,16 @@ def test_example_plugin_ep_devices(self): hw_metadata = hw_device.metadata self.assertGreater(len(hw_metadata), 0) # Should have at least SPDRP_HARDWAREID on Windows + test_mem_info = test_ep_device.memory_info(onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertIsNotNone(test_mem_info) + del test_mem_info + + test_sync_stream = test_ep_device.create_sync_stream() + self.assertIsNotNone(test_sync_stream) + stream_handle = test_sync_stream.get_handle() + self.assertIsNotNone(stream_handle) + del test_sync_stream + # Add EP plugin's OrtEpDevice to the SessionOptions. 
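+ # add_provider_for_devices() pairs the list of OrtEpDevice instances with a dict of
+ # provider options; the {"opt1": "val1"} entry below is just an arbitrary test option.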
sess_options = onnxrt.SessionOptions() sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) @@ -282,6 +292,55 @@ def test_example_plugin_ep_data_transfer(self): self.unregister_execution_provider_library(ep_name) + + def test_copy_tensors(self): + """ + Test the global copy_tensors API between OrtValue objects + using EP plug-in data transfer + """ + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + ep_lib_path = "example_plugin_ep.dll" + try: + ep_lib_path = get_name("example_plugin_ep.dll") + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + # Generate 2 numpy arrays + a = np.random.rand(3, 2).astype(np.float32) + b = np.random.rand(3, 2).astype(np.float32) + + # Create OrtValue from numpy arrays on EP device + # the example EP pretends to use GPU memory, so we place it there + a_device = onnxrt.OrtValue.ortvalue_from_numpy(a, "gpu", 0, 0xBE57) + b_device = onnxrt.OrtValue.ortvalue_from_numpy(b, "gpu", 0, 0xBE57) + + # Create destination ort values with the same shape on CPU + a_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(a.shape, a.dtype) + b_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(b.shape, b.dtype) + + # Source and destination lists + src_list = [a_device, b_device] + dst_list = [a_cpu_copy, b_cpu_copy] + # Pass None for the stream: these copies are effectively CPU-to-CPU since the + # example EP's memory is CPU-backed, and None is an allowed value worth exercising + onnxrt.copy_tensors(src_list, dst_list, None) + + # Release the OrtValue on the EP device + # before the EP library is unregistered + del src_list + del a_device + del b_device + + # Verify the contents + np.testing.assert_array_equal(a, a_cpu_copy.numpy()) + np.testing.assert_array_equal(b, b_cpu_copy.numpy()) + + self.unregister_execution_provider_library(ep_name) + if __name__ == "__main__": unittest.main(verbosity=1) diff --git a/onnxruntime/test/python/transformers/test_moe_cpu.py b/onnxruntime/test/python/transformers/test_moe_cpu.py new file mode 100644 index 0000000000000..d6cbcc64733d4 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_moe_cpu.py @@ -0,0 +1,473 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# +# Regular MoE CPU kernel testing implementation - SwiGLU Interleaved Only +# +# This file tests the non-quantized MoE CPU implementation with SwiGLU +# activation in interleaved format and validates parity between +# PyTorch reference implementation and ONNX Runtime CPU kernel. +# +# Based on the CUDA test structure for consistency.
+# -------------------------------------------------------------------------- + +import itertools +import time +import unittest + +import numpy +import torch +import torch.nn as nn +import torch.nn.functional as F +from onnx import TensorProto, helper +from parameterized import parameterized + +from onnxruntime import InferenceSession, SessionOptions + +# Device and provider settings for CPU +device = torch.device("cpu") +ort_provider = ["CPUExecutionProvider"] + +torch.manual_seed(42) +numpy.random.seed(42) + +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.BFLOAT16: torch.bfloat16, + TensorProto.UINT8: torch.uint8, +} + +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} + +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", + TensorProto.BFLOAT16: "BF16", +} + + +class SwigluMoeConfig: + def __init__( + self, + hidden_size=2048, + intermediate_size=2048, + num_experts_per_token=2, + num_local_experts=8, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts_per_token = num_experts_per_token + self.num_local_experts = num_local_experts + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + x_glu = x[..., ::2] + x_linear = x[..., 1::2] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +def make_onnx_intializer(name: str, tensor: torch.Tensor, shape, onnx_dtype): + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + if torch_dtype == torch.bfloat16: + numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy() + initializer = helper.make_tensor( + name=name, + data_type=TensorProto.BFLOAT16, + dims=shape, + vals=numpy_vals_uint16.tobytes(), + raw=True, + ) + else: + initializer = helper.make_tensor( + name=name, + data_type=onnx_dtype, + dims=shape, + vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist() + if onnx_dtype == TensorProto.UINT8 + else tensor.detach().to(torch_dtype).flatten().tolist(), + raw=False, + ) + return initializer + + +def create_swiglu_moe_onnx_graph( + num_tokens: int, + num_experts: int, + hidden_size: int, + inter_size: int, + topk: int, + onnx_dtype: int, + quant_bits: int, + fc1_experts_weights: torch.Tensor, + fc1_experts_bias: torch.Tensor, + fc2_experts_weights: torch.Tensor, + fc2_experts_bias: torch.Tensor, + fc1_experts_weight_scale: torch.Tensor = None, + fc2_experts_weight_scale: torch.Tensor = None, +): + use_quant = quant_bits > 0 + op_name = "QMoE" if use_quant else "MoE" + + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_weight_scale", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_weight_scale", + "fc2_experts_bias", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_bias", + ] + 
) + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="swiglu", + swiglu_fusion=1, + activation_alpha=1.702, + activation_beta=1.0, + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + components = 2 if quant_bits == 4 else 1 + fc1_weight_shape = [num_experts, 2 * inter_size, hidden_size // components] + fc1_bias_shape = [num_experts, 2 * inter_size] + fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] + + fc2_weight_shape = [num_experts, hidden_size, inter_size // components] + fc2_bias_shape = [num_experts, hidden_size] + fc2_experts_weight_scale_shape = [num_experts, hidden_size] + + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type] + + initializers = [ + make_onnx_intializer( + "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype), + make_onnx_intializer( + "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype), + ] + + if use_quant: + initializers.extend( + [ + make_onnx_intializer( + "fc1_experts_weight_scale", + fc1_experts_weight_scale.to(torch_dtype), + fc1_experts_weight_scale_shape, + onnx_dtype, + ), + make_onnx_intializer( + "fc2_experts_weight_scale", + fc2_experts_weight_scale.to(torch_dtype), + fc2_experts_weight_scale_shape, + onnx_dtype, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + onnx_dtype, + [num_tokens, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class SparseMoeBlockORTHelper(nn.Module): + def __init__(self, quant_bits=0, onnx_dtype=None): + super().__init__() + self.quant_bits = quant_bits + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 + + def create_ort_session(self, moe_onnx_graph): + sess_options = SessionOptions() + sess_options.log_severity_level = 2 + + try: + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception as e: + print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print("Skipping ONNX Runtime execution for this test case.") + return None + + return ort_session + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + pass + + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + router_logits = 
self.gate(hidden_states_flat) + + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] + + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + } + + ort_inputs = { + "input": tensors["input"].detach().cpu().numpy(), + "router_probs": tensors["router_probs"].detach().cpu().numpy(), + } + + if enable_performance_test: + repeat = 1000 + s = time.time() + for _ in range(repeat): + self.ort_sess.run(None, ort_inputs) + e = time.time() + print(f"MoE CPU kernel time: {(e - s) / repeat * 1000} ms") + ort_outputs = self.ort_sess.run(None, ort_inputs) + else: + ort_outputs = self.ort_sess.run(None, ort_inputs) + + output_tensor = torch.from_numpy(ort_outputs[0]).to(device) + + return output_tensor.reshape(batch_size, sequence_length, hidden_dim) + + def parity_check(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + } + + atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"] + if ort_output is not None: + print( + f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," + f" batch: {self.batch_size}, seq_len: {self.sequence_length}," + f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}" + ) + torch.testing.assert_close( + ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol + ) + + def benchmark_ort(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + + for expert in self.experts: + w1_weight = expert.w1.weight.data.clone() + w2_weight = expert.w2.weight.data.clone() + w1_bias = expert.w1.bias.data.clone() + w2_bias = expert.w2.bias.data.clone() + + fc1_w_list.append(w1_weight) + fc2_w_list.append(w2_weight) + fc1_b_list.append(w1_bias) + fc2_b_list.append(w2_bias) + + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = create_swiglu_moe_onnx_graph( + num_tokens=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + hidden_size=self.hidden_dim, + inter_size=self.ffn_dim, + topk=self.top_k, + onnx_dtype=self.onnx_dtype, + quant_bits=self.quant_bits, + fc1_experts_weights=fc1_experts_weights, + 
fc1_experts_bias=fc1_experts_bias, + fc2_experts_weights=fc2_experts_weights, + fc2_experts_bias=fc2_experts_bias, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + # Compute full softmax over all experts (same as CUDA) + full_probs = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(full_probs, self.top_k, dim=-1) + + # For normalize_routing_weights=1: normalize by sum of top-k values (same as CUDA) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +swiglu_test_cases = list( + itertools.product( + [1, 2], # batch_size + [16, 32], # sequence_length + [0], # quant_bits (CPU kernel only supports float32) + ) +) + +perf_test_cases = list( + itertools.product( + [1], # batch_size + [128], # sequence_length + [0], # quant_bits (CPU kernel only supports float32) + ) +) + + +class TestSwigluMoECPU(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.parity_check() + + +class TestSwigluMoECPUPerf(unittest.TestCase): + @parameterized.expand(perf_test_cases) + def test_swiglu_moe_perf(self, batch_size, sequence_length, quant_bits): + hidden_size = 1024 + intermediate_size = 2048 + num_experts_per_token = 4 + num_local_experts = 16 + config = SwigluMoeConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts_per_token=num_experts_per_token, + num_local_experts=num_local_experts, + ) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.benchmark_ort() + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index efaaca29a01b6..0292111b16962 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -128,6 +128,148 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): # Calculate scale like C++ implementation abs_max = weights.abs().max(dim=-1, keepdim=True)[0] + + # Set minimum scale to avoid division by zero + scale = torch.clamp(abs_max, min=1e-6) + + # Quantization ranges 
for symmetric quantization + if is_4_bit_quantization: + qmin, qmax = -8, 7 + zero_point = 8 # Offset to make values unsigned + else: + qmin, qmax = -128, 127 + zero_point = 128 # Offset to make values unsigned + + # Quantize using double precision division and C-like rounding (half away from zero) + scaled = weights.double() / scale.double() + sign = torch.sign(scaled) + abs_scaled = torch.abs(scaled) + quant_rounded = torch.floor(abs_scaled + 0.5) + quantized = torch.clamp((sign * quant_rounded).to(torch.int32), qmin, qmax).to(weights.dtype) + + # Convert to unsigned and pack for storage + if is_4_bit_quantization: + # Convert to unsigned 4-bit and pack into uint8 + unsigned_quantized = (quantized + zero_point).to(torch.uint8) + + # Pack two 4-bit values into one uint8 + packed_size = (weights.shape[-1] + 1) // 2 + packed_quantized = torch.zeros((*weights.shape[:-1], packed_size), dtype=torch.uint8, device=weights.device) + + for i in range(0, weights.shape[-1], 2): + val1 = unsigned_quantized[..., i] + val2 = unsigned_quantized[..., i + 1] if i + 1 < weights.shape[-1] else torch.zeros_like(val1) + packed_quantized[..., i // 2] = (val1 & 0xF) | ((val2 & 0xF) << 4) + + quantized_storage = packed_quantized + else: + # 8-bit: convert to unsigned uint8 + quantized_storage = (quantized + zero_point).to(torch.uint8) + + # Dequantize for verification (use float32 scale for higher precision) + dequantized = quantized.to(torch.float32) * scale + + return scale.squeeze(-1).to(torch.float32), quantized_storage, dequantized + + +def quant_dequant_blockwise(weights, block_size, is_4_bit_quantization: bool = True): + """ + Block-wise quantization and dequantization for testing purposes. + This function uses symmetric quantization centered around 0 (no zero-point). 
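+
+ Illustrative example (a sketch for this test helper, not the C++ kernel): with
+ block_size=2 and 4-bit quantization, the row [0.5, -1.0, 0.25] splits into blocks
+ [0.5, -1.0] and [0.25]. The first block's abs_max is 1.0, so its scale is 1.0 / 7.0;
+ 0.5 quantizes to round(0.5 * 7.0) = 4, stored unsigned as 4 + 8 = 12, and -1.0 maps
+ to -7, stored as 1.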
+ + Args: + weights: Input tensor of shape [rows, cols] + block_size: Size of each quantization block + is_4_bit_quantization: Whether to use 4-bit (True) or 8-bit (False) quantization + + Returns: + scales: Scale tensor of shape [rows, num_blocks] + quantized: Quantized tensor + dequantized: Dequantized tensor for verification + """ + rows, cols = weights.shape + num_blocks = (cols + block_size - 1) // block_size + + # Handle edge case of all-zero weights tensor + if torch.all(weights == 0): + scales = torch.zeros((rows, num_blocks), dtype=torch.float16, device=weights.device) + if is_4_bit_quantization: + packed_size = (cols + 1) // 2 + quantized = torch.zeros((rows, packed_size), dtype=torch.uint8, device=weights.device) + else: + quantized = torch.zeros((rows, cols), dtype=torch.uint8, device=weights.device) + dequantized = torch.zeros_like(weights) + return scales, quantized, dequantized + + # Initialize output tensors; use float32 for scales to reduce precision loss + scales = torch.zeros((rows, num_blocks), dtype=torch.float32, device=weights.device) + dequantized = torch.zeros_like(weights) + + # Quantization ranges and zero point + if is_4_bit_quantization: + qmin, qmax = -8, 7 + zero_point = 8 + packed_size = (cols + 1) // 2 + quantized = torch.zeros((rows, packed_size), dtype=torch.uint8, device=weights.device) + else: + qmin, qmax = -128, 127 + zero_point = 128 + quantized = torch.zeros((rows, cols), dtype=torch.uint8, device=weights.device) + + # Process each block with higher-precision math to match C++ behavior + for row in range(rows): + for block_idx in range(num_blocks): + start_col = block_idx * block_size + end_col = min(start_col + block_size, cols) + + # Get block data + block_data = weights[row, start_col:end_col] + + # Calculate absolute max and ensure small epsilon to avoid div-by-zero + abs_max = block_data.abs().max() + abs_max = torch.clamp(abs_max, min=1e-8) + + # Compute scale consistent with C++: use 7.0 for 4-bit positive max, 127.0 for 8-bit + if is_4_bit_quantization: + # Use higher precision then keep as float32 for scale + scale = (abs_max.double() / 7.0).float() + 1e-12 + else: + scale = (abs_max.double() / 127.0).float() + 1e-12 + + scales[row, block_idx] = scale.to(torch.float32) + + if scale == 0: + continue + + # Quantize using double precision for the division to reduce rounding error + scaled = block_data.double() / scale.double() + # Emulate C's round() behavior (round half away from zero) to match C++ implementation + sign = torch.sign(scaled) + abs_scaled = torch.abs(scaled) + quant_rounded = torch.floor(abs_scaled + 0.5) + quantized_block = (sign * quant_rounded).clamp(qmin, qmax).to(torch.int32) + + # Pack for 4-bit or store directly for 8-bit + if is_4_bit_quantization: + for i in range(0, end_col - start_col, 2): + col_idx = start_col + i + packed_idx = col_idx // 2 + + val1 = int(quantized_block[i]) + zero_point + val2 = int(quantized_block[i + 1]) + zero_point if i + 1 < len(quantized_block) else zero_point + + # Pack two 4-bit values into one uint8 + packed_val = (val1 & 0xF) | ((val2 & 0xF) << 4) + quantized[row, packed_idx] = packed_val + else: + quantized_vals = (quantized_block + zero_point).to(torch.uint8) + quantized[row, start_col:end_col] = quantized_vals + + # Dequantize for verification (signed quantized values multiplied by scale) + signed = quantized_block.to(torch.float32) + dequantized[row, start_col:end_col] = signed * scale + + return scales, quantized, dequantized abs_max = torch.clamp(abs_max, min=1e-8) # More 
conservative clamping for better precision if is_4_bit_quantization: @@ -247,6 +389,7 @@ def create_cpu_moe_onnx_graph( use_quant=False, quant_bits=4, swiglu_interleaved=False, + block_size=0, # New parameter for block-wise quantization ): if not has_onnx: return None @@ -254,7 +397,8 @@ def create_cpu_moe_onnx_graph( inter_size = intermediate_size topk = top_k - use_quant = True + # Only override use_quant for backward compatibility if not explicitly set + # use_quant = True # This line was causing issues for regular MoE tests if fc1_scales is None and use_quant: return None @@ -267,26 +411,47 @@ def create_cpu_moe_onnx_graph( assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" assert fc1_scales is not None, "FC1 scales must be provided for QMoE" assert fc2_scales is not None, "FC2 scales must be provided for QMoE" - assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" - assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" + # Accept float16 or float32 scales; tests may produce float32 for better precision + assert fc1_scales.dtype in (torch.float16, torch.float32), "FC1 scales must be float16 or float32 for QMoE" + assert fc2_scales.dtype in (torch.float16, torch.float32), "FC2 scales must be float16 or float32 for QMoE" if not has_onnx: return None - op_name = "QMoE" - inputs = [ - "input", - "router_probs", - "fc1_experts_weights", - "fc1_scales", - "", - "fc2_experts_weights", - "fc2_scales", - "", - ] + # Set operator name and inputs based on quantization mode + if use_quant: + op_name = "QMoE" + inputs = [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + ] + else: + # For regular (non-quantized) MoE, use different operator and input layout + op_name = "MoE" # Regular MoE operator + inputs = [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias" if fc1_bias is not None else "", # fc1_bias as input 3 + "fc2_experts_weights", + "fc2_experts_bias" if fc2_bias is not None else "", # fc2_bias as input 5 + "", # fc3_experts_weights (not used) + "", # fc3_experts_bias (not used) + ] activation = "swiglu" if use_swiglu else "silu" + # Set normalization behavior based on operator type: + # - QMoE: Raw logits passed, needs normalization in C++ kernel + # - Regular MoE: Pre-computed probabilities passed, no additional normalization needed + normalize_routing = 1 if use_quant else 0 + nodes = [ helper.make_node( op_name, @@ -294,13 +459,14 @@ def create_cpu_moe_onnx_graph( ["output"], "MoE_0", k=topk, - normalize_routing_weights=1, # Use proper routing normalization to match PyTorch behavior + normalize_routing_weights=normalize_routing, activation_type=activation, # Add new attributes with backwards-compatible default values - swiglu_fusion=1 if (use_swiglu and swiglu_interleaved) else 0, # 1 = fused and interleaved + swiglu_fusion=1 if use_swiglu else 0, # 1 if using SwiGLU activation swiglu_limit=7.0, activation_alpha=1.702, activation_beta=1.0, + swiglu_interleaved=1 if swiglu_interleaved else 0, # Enable this attribute domain="com.microsoft", ), ] @@ -308,6 +474,10 @@ def create_cpu_moe_onnx_graph( if use_quant: nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + # Add block_size attribute for block-wise quantization + if block_size > 0: + nodes[0].attribute.extend([helper.make_attribute("block_size", block_size)]) + # Weights are store in column major order. 
Need pack 2 int4 values into uint8. # Use the actual tensor shapes instead of calculating them to avoid size mismatches fc1_shape = list(fc1_experts_weights.shape) @@ -318,79 +488,113 @@ def create_cpu_moe_onnx_graph( weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + # Use raw bytes from C-contiguous numpy arrays to ensure the exact memory layout + # of the packed uint8 weight tensors is preserved when writing the ONNX initializer. + fc1_np = fc1_experts_weights.detach().cpu().numpy().astype(weight_numpy_type) + fc2_np = fc2_experts_weights.detach().cpu().numpy().astype(weight_numpy_type) + fc1_np = numpy.ascontiguousarray(fc1_np) + fc2_np = numpy.ascontiguousarray(fc2_np) + initializers = [ helper.make_tensor( "fc1_experts_weights", weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + fc1_np.tobytes(), + raw=True, ), helper.make_tensor( "fc2_experts_weights", weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + fc2_np.tobytes(), + raw=True, ), ] - fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] - fc2_scale_shape = [num_experts, hidden_size] + # Calculate scale tensor shapes based on block_size + if block_size > 0: + # Block-wise quantization: 3D scale tensors + fc1_blocks_per_row = (hidden_size + block_size - 1) // block_size + fc2_blocks_per_row = (inter_size + block_size - 1) // block_size + + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size, fc1_blocks_per_row] + fc2_scale_shape = [num_experts, hidden_size, fc2_blocks_per_row] + + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) * fc1_blocks_per_row + fc2_scale_size = num_experts * hidden_size * fc2_blocks_per_row + else: + # Row-wise quantization: 2D scale tensors + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] + fc2_scale_shape = [num_experts, hidden_size] - fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) - fc2_scale_size = num_experts * hidden_size + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) + fc2_scale_size = num_experts * hidden_size # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions - # Handle different possible scale tensor structures for fc1_scales - if len(fc1_scales.shape) == 4: - # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output - if use_swiglu: - fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() - else: - fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() - elif len(fc1_scales.shape) == 2: - # 2D case: already flattened, just ensure correct size + # Process scale tensors based on whether block-wise quantization is used + if block_size > 0: + # For block-wise quantization, the scales are already in the correct 3D shape + # [num_experts, output_features, num_blocks] from quant_dequant_blockwise + # Convert scales to the selected ONNX dtype (prefer float32 for higher precision) fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: - # For SwiGLU, duplicate the scales to cover both gate and value components - 
fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() - elif fc1_scale_tensor.size > fc1_scale_size: - # Truncate to expected size - fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() else: - # Other cases: flatten and truncate/pad as needed - fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc1_scale_tensor.size > fc1_scale_size: - fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] - elif fc1_scale_tensor.size < fc1_scale_size: - # Pad with ones if too small - pad_size = fc1_scale_size - fc1_scale_tensor.size - fc1_scale_tensor = numpy.concatenate([fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)]) - - # Process scale tensor for proper shape + # For row-wise quantization, handle different possible scale tensor structures for fc1_scales + if len(fc1_scales.shape) == 4: + # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output + if use_swiglu: + fc1_scale_tensor = ( + fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + ) + else: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc1_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: + # For SwiGLU, duplicate the scales to cover both gate and value components + fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() + elif fc1_scale_tensor.size > fc1_scale_size: + # Truncate to expected size + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc1_scale_tensor.size > fc1_scale_size: + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + elif fc1_scale_tensor.size < fc1_scale_size: + # Pad with ones if too small + pad_size = fc1_scale_size - fc1_scale_tensor.size + fc1_scale_tensor = numpy.concatenate( + [fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)] + ) + + # Handle different possible scale tensor structures for fc2_scales + if len(fc2_scales.shape) == 4: + # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output + fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc2_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + # Truncate to expected size + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + elif fc2_scale_tensor.size < fc2_scale_size: + # Pad with ones if too small + pad_size = fc2_scale_size - fc2_scale_tensor.size + fc2_scale_tensor = numpy.concatenate( + [fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)] + ) + + # Process scale tensors for proper data format fc1_scale_data_list = fc1_scale_tensor.tolist() 
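# (tolist() materializes the numpy scale values as plain Python floats, the form helper.make_tensor accepts for non-raw vals)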
fc1_scale_data = fc1_scale_data_list - - # Handle different possible scale tensor structures for fc2_scales - if len(fc2_scales.shape) == 4: - # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output - fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() - elif len(fc2_scales.shape) == 2: - # 2D case: already flattened, just ensure correct size - fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc2_scale_tensor.size > fc2_scale_size: - # Truncate to expected size - fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] - else: - # Other cases: flatten and truncate/pad as needed - fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc2_scale_tensor.size > fc2_scale_size: - fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] - elif fc2_scale_tensor.size < fc2_scale_size: - # Pad with ones if too small - pad_size = fc2_scale_size - fc2_scale_tensor.size - fc2_scale_tensor = numpy.concatenate([fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)]) - - # Process scale tensor for proper shape fc2_scale_data_list = fc2_scale_tensor.tolist() fc2_scale_data = fc2_scale_data_list @@ -594,10 +798,7 @@ class SparseMoeBlockORTHelper(nn.Module): def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() self.quant_bits = quant_bits - if onnx_dtype is None: - self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT - else: - self.onnx_dtype = onnx_dtype + self.onnx_dtype = onnx_dtype self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): @@ -619,18 +820,55 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: if self.ort_sess is None: + print(f"ERROR: ORT session is None for {self.__class__.__name__}") return None batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states_flat) + # Different routing logic for QMoE vs regular MoE: + # - QMoE expects raw logits (does its own softmax internally) + # - Regular MoE expects pre-computed routing probabilities + if hasattr(self, "quant_bits") and self.quant_bits > 0: + # QMoE: Pass raw logits directly (QMoE does softmax internally) + router_input = router_logits + # print("DEBUG: Using QMoE routing (raw logits)") + else: + # Regular MoE: Apply the same routing logic as PyTorch reference + # This converts raw logits to proper routing probabilities + routing_weights, selected_experts = masked_sampling_omp_inference( + router_logits, + top_k=self.top_k, + jitter_eps=self.router_jitter_noise, + training=False, + ) + + # IMPORTANT: The routing weights from masked_sampling_omp_inference sum to top_k, + # but ONNX Runtime expects normalized probabilities that sum to 1.0 + # Normalize the routing weights per token + routing_weights = routing_weights / routing_weights.sum(dim=1, keepdim=True) + + # Create proper router probabilities tensor that matches PyTorch routing + router_input = torch.zeros_like(router_logits) + for i in range(router_logits.shape[0]): # For each token + for j in range(self.top_k): # For each top-k expert + expert_idx = selected_experts[i, j] + router_input[i, expert_idx] = routing_weights[i, j] + + # print("DEBUG: Using regular MoE routing (processed 
probabilities)") + + # print(f"DEBUG: router_input stats: mean={router_input.mean():.6f}, std={router_input.std():.6f}") + # print( + # f"DEBUG: hidden_states_flat stats: mean={hidden_states_flat.mean():.6f}, std={hidden_states_flat.std():.6f}" + # ) + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] tensors = { "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), - "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + "output": torch.zeros((batch_size * sequence_length, hidden_dim), device=device, dtype=torch_dtype), } try: @@ -656,10 +894,14 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False buffer_ptr=tensor.data_ptr(), ) + # print("DEBUG: About to run ORT inference...") + iobinding.synchronize_inputs() self.ort_sess.run_with_iobinding(iobinding) iobinding.synchronize_outputs() + # print("DEBUG: ORT inference completed successfully") + if enable_performance_test: repeat = 100 s = time.time() @@ -687,14 +929,29 @@ def recreate_onnx_model(self): is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + if self.block_size > 0: + # Use block-wise quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant_blockwise( + self.experts[i].w1.weight, self.block_size, is_4_bit + ) + w2_scale, pre_qweight2, w2_qdq = quant_dequant_blockwise( + self.experts[i].w2.weight, self.block_size, is_4_bit + ) + else: + # Use row-wise quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) if self.use_swiglu: if self.swiglu_interleaved: pass else: - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) + if self.block_size > 0: + w3_scale, pre_qweight3, w3_qdq = quant_dequant_blockwise( + self.experts[i].w3.weight, self.block_size, is_4_bit + ) + else: + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) gate_weights = pre_qweight1 value_weights = pre_qweight3 @@ -755,6 +1012,7 @@ def recreate_onnx_model(self): use_quant=True, # Always use QMoE quant_bits=self.quant_bits, swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + block_size=self.block_size, # Add block_size for block-wise quantization ) except Exception: self.moe_onnx_graph = None @@ -803,6 +1061,45 @@ def parity_check(self): print(f"Parity check - {act_type} {self.quant_bits}-bit: max_diff = {max_diff:.6f}") + # Diagnostic dump: when differences are large, show the index and nearby values + if max_diff > 1e-3: + diff = (torch_output.cpu() - ort_output.cpu()).abs() + idx = torch.argmax(diff) + flat_idx = int(idx) + # Derive coordinates (batch, seq, hidden) from flattened index + total_elems = torch_output.numel() + # Work in flattened [batch, seq, hidden] ordering + hidden_dim = self.hidden_dim + seq = self.sequence_length + # Clamp to safe bounds + flat_idx = min(flat_idx, total_elems - 1) + i = flat_idx // (hidden_dim) + j = i // seq + k = flat_idx % hidden_dim + print( + f"Diagnostic - max diff at flat_idx={flat_idx} -> sample (batch_idx={j}, seq_idx={i % seq}, hidden_idx={k})" + ) + print("Torch sample:", torch_output.cpu().reshape(-1, hidden_dim)[i, k].item()) + 
print("ORT sample:", ort_output.cpu().reshape(-1, hidden_dim)[i, k].item()) + # Print routing and per-expert contributions for this token from the PyTorch reference + try: + hidden_states_flat = hidden_state.view(-1, hidden_dim) + token_vec = hidden_states_flat[i : i + 1] + gate_logits = self.gate(token_vec) + topk_vals, topk_experts = torch.topk(gate_logits, self.top_k, dim=-1) + topk_soft = F.softmax(topk_vals, dim=1) + print("Gate logits:", gate_logits.detach().cpu().numpy()) + print("Selected experts:", topk_experts.detach().cpu().numpy()) + print("Routing weights:", topk_soft.detach().cpu().numpy()) + # Compute per-expert contributions for selected experts + for idx_e, e in enumerate(topk_experts[0].tolist()): + expert_layer = self.experts[e] + expert_out = expert_layer(token_vec) + contrib = expert_out[0, k].item() * topk_soft[0, idx_e].item() + print(f"Expert {e} contrib at hidden {k}: {contrib}") + except Exception as _: + pass + ort_dtype_quant_bits_tolerance_map = { "FP32:0": (5e-3, 1e-3), "FP16:0": (5e-2, 1e-3), @@ -843,7 +1140,13 @@ def small_test_cases(): class SwigluMoEBlock(SparseMoeBlockORTHelper): def __init__( - self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + self, + config: SwigluMoeConfig, + batch_size: int, + sequence_length: int, + quant_bits: int = 0, + onnx_dtype=None, + block_size: int = 0, ): super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size @@ -852,6 +1155,7 @@ def __init__( self.top_k = config.num_experts_per_token self.use_swiglu = True self.swiglu_interleaved = True + self.block_size = block_size # Store block_size for QMoE use_quant = self.quant_bits > 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) @@ -921,7 +1225,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): def __init__( - self, config: PhiMoEConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + self, + config: PhiMoEConfig, + batch_size: int, + sequence_length: int, + quant_bits: int = 0, + onnx_dtype=None, + block_size: int = 0, ): super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size @@ -931,6 +1241,7 @@ def __init__( self.router_jitter_noise = config.router_jitter_noise self.use_swiglu = True self.swiglu_interleaved = True + self.block_size = block_size # Store block_size for QMoE use_quant = self.quant_bits > 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) @@ -950,8 +1261,14 @@ def __init__( else: is_4_bit = self.quant_bits == 4 - scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) - scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + if self.block_size > 0: + # Use block-wise quantization + scale1, pre_qweight1, w1_qdq = quant_dequant_blockwise(expert.w1.weight, self.block_size, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant_blockwise(expert.w2.weight, self.block_size, is_4_bit) + else: + # Use row-wise quantization + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) expert.w1.weight.data = w1_qdq expert.w2.weight.data = w2_qdq @@ -990,6 +1307,7 @@ def __init__( use_quant=use_quant, quant_bits=self.quant_bits, swiglu_interleaved=self.swiglu_interleaved, + block_size=self.block_size, # Add block_size for block-wise quantization ) self.ort_sess = 
self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None @@ -1001,9 +1319,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # Match CPU implementation: select top-k experts by logits, then softmax over those logits + routing_weights_vals, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights_vals, dim=1, dtype=torch.float) routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( @@ -1038,16 +1356,61 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: (2, 16, 8), ] +# Define test cases for block-wise quantization +phi3_blockwise_test_cases = [ + (1, 32, 4, 32), # batch_size, sequence_length, quant_bits, block_size + (1, 32, 8, 64), + (2, 16, 4, 32), + (2, 16, 8, 64), +] + @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestPhiQMoECPU(unittest.TestCase): @parameterized.expand(phi3_test_cases) def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + # Create unique seed based on test parameters to ensure different inputs for each test + base_seed = 2000 # Different base seed from other tests + param_hash = hash((batch_size, sequence_length, quant_bits)) + unique_seed = base_seed + abs(param_hash) % 1000 + + torch.manual_seed(unique_seed) + numpy.random.seed(unique_seed) + + test_config = ( + f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, seed={unique_seed}" + ) + print(f"Running Phi3 QMoE test: {test_config}") + + config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) + + phi3_moe = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = phi3_moe.forward(hidden_states) + + # Verify output shape and basic properties + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + phi3_moe.parity_check() + + @parameterized.expand(phi3_blockwise_test_cases) + def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): torch.manual_seed(42) numpy.random.seed(42) - test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" - print(f"Running Phi3 QMoE test: {test_config}") + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, block_size={block_size}" + print(f"Running Phi3 QMoE block-wise test: {test_config}") config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) @@ -1057,6 +1420,7 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): sequence_length=sequence_length, quant_bits=quant_bits, onnx_dtype=TensorProto.FLOAT, + block_size=block_size, # Enable block-wise quantization ) hidden_states = torch.randn(batch_size, 
sequence_length, config.hidden_size).to(torch.float32) @@ -1081,16 +1445,60 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): (2, 16, 8), ] +# Define test cases for block-wise quantization +swiglu_blockwise_test_cases = [ + (1, 32, 4, 32), # batch_size, sequence_length, quant_bits, block_size + (1, 32, 8, 64), + (2, 16, 4, 32), + (2, 16, 8, 64), +] + @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestSwigluQMoECPU(unittest.TestCase): @parameterized.expand(swiglu_test_cases) def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + # Create unique seed based on test parameters to ensure different inputs for each test + base_seed = 1000 # Different base seed from regular MoE tests + param_hash = hash((batch_size, sequence_length, quant_bits)) + unique_seed = base_seed + abs(param_hash) % 1000 + + torch.manual_seed(unique_seed) + numpy.random.seed(unique_seed) + + test_config = ( + f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, seed={unique_seed}" + ) + print(f"Running SwiGLU test: {test_config}") + + config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) + + swiglu_moe = SwigluMoEBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = swiglu_moe.forward(hidden_states) + + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + swiglu_moe.parity_check() + + @parameterized.expand(swiglu_blockwise_test_cases) + def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): torch.manual_seed(42) numpy.random.seed(42) - test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" - print(f"Running SwiGLU test: {test_config}") + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, block_size={block_size}" + print(f"Running SwiGLU block-wise test: {test_config}") config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) @@ -1100,6 +1508,7 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): sequence_length=sequence_length, quant_bits=quant_bits, onnx_dtype=TensorProto.FLOAT, + block_size=block_size, # Enable block-wise quantization ) hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) @@ -1114,5 +1523,173 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): swiglu_moe.parity_check() +@unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") +class TestQMoESwiGLUBenchmark(unittest.TestCase): + """Benchmark tests for QMoE SwiGLU performance measurement.""" + + def test_qmoe_swiglu_throughput_benchmark(self): + """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" + if disable_cpu_qmoe_tests: + self.skipTest("QMoE CPU tests disabled") + + print("\n=== QMoE SwiGLU Throughput Benchmark ===") + + # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits) + configs = [ + ("Medium-4bit", 2880, 2880, 32, 4, 4), + 
("Medium-8bit", 2880, 2880, 32, 4, 8), + ] + + batch_size = 1 + sequence_length = 512 + num_runs = 30 + + results = [] + + for config_name, hidden_size, intermediate_size, num_experts, top_k, quant_bits in configs: + torch.manual_seed(42) + numpy.random.seed(42) + + print(f"\nTesting {config_name}:") + print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}") + print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit") + + try: + # Create config and model + config = PhiMoEConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_local_experts=num_experts, + num_experts_per_tok=top_k, + ) + + qmoe_swiglu = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + # Create test input with fixed sequence length to match ONNX model + full_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to(torch.float32) + + # For TTFT simulation, we'll measure single forward pass time + # This represents the time to process one token in autoregressive generation + + # Initialize variables + torch_output = None + ort_output = None + + # Warm up with full context + for _ in range(3): + _ = qmoe_swiglu.forward(full_hidden_states) + + # Benchmark PyTorch TTFT (Time to First Token) + # Measure time for a single forward pass (represents token generation time) + torch.manual_seed(42) + + start_time = time.time() + for _ in range(num_runs): + torch_output = qmoe_swiglu.forward(full_hidden_states) + end_time = time.time() + torch_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second (throughput) + # For sequence generation, this represents the rate at which we can generate tokens + torch_tokens_per_sec = 1000.0 / torch_ttft_ms # 1 token / (time_ms / 1000) + + print(f" PyTorch TTFT: {torch_ttft_ms:.3f} ms (per token generation time)") + print(f" PyTorch Throughput: {torch_tokens_per_sec:.1f} tokens/sec") + + # Benchmark ONNX Runtime + ort_ttft_ms = 0 + ort_tokens_per_sec = 0 + speedup = 0 + throughput_ratio = 0 + max_diff = 0 + + model_updated = qmoe_swiglu.recreate_onnx_model() + if model_updated and qmoe_swiglu.ort_sess is not None: + # Warm up ORT with full context + for _ in range(3): + _ = qmoe_swiglu.ort_forward(full_hidden_states) + + torch.manual_seed(42) + + # Measure ONNX Runtime TTFT (Time to First Token) + start_time = time.time() + for _ in range(num_runs): + ort_output = qmoe_swiglu.ort_forward(full_hidden_states) + end_time = time.time() + ort_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second for ONNX Runtime + ort_tokens_per_sec = 1000.0 / ort_ttft_ms # 1 token / (time_ms / 1000) + + speedup = torch_ttft_ms / ort_ttft_ms if ort_ttft_ms > 0 else 0 + throughput_ratio = ort_tokens_per_sec / torch_tokens_per_sec if torch_tokens_per_sec > 0 else 0 + + print(f" ONNX RT TTFT: {ort_ttft_ms:.3f} ms (per token generation time)") + print(f" ONNX RT Throughput: {ort_tokens_per_sec:.1f} tokens/sec") + print(f" TTFT Speedup: {speedup:.2f}x") + print(f" Throughput Gain: {throughput_ratio:.2f}x") + else: + print(" ONNX RT: Not available") + + # Calculate max difference if both outputs available + if torch_output is not None and ort_output is not None: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max().item() + print(f" Max diff: {max_diff:.6f}") + + results.append( + { + "config": config_name, + "torch_ttft_ms": torch_ttft_ms, + "torch_tokens_per_sec": torch_tokens_per_sec, 
+ "ort_ttft_ms": ort_ttft_ms, + "ort_tokens_per_sec": ort_tokens_per_sec, + "speedup": speedup, + "throughput_ratio": throughput_ratio, + "max_diff": max_diff, + } + ) + + except Exception as e: + print(f" Error: {e}") + continue + + # Summary + print("\n=== Token Generation Time & Throughput Summary ===") + print( + f"{'Config':<15} {'PT Time':<10} {'PT tok/s':<10} {'ORT Time':<11} {'ORT tok/s':<11} {'Time Gain':<10} {'Throughput':<11} {'Max Diff':<10}" + ) + print("-" * 105) + for result in results: + config = result["config"] + torch_ttft = result["torch_ttft_ms"] + torch_tps = result["torch_tokens_per_sec"] + ort_ttft = result["ort_ttft_ms"] + ort_tps = result["ort_tokens_per_sec"] + speedup = result["speedup"] + throughput_ratio = result["throughput_ratio"] + max_diff = result["max_diff"] + + ort_ttft_str = f"{ort_ttft:.3f}" if ort_ttft > 0 else "N/A" + ort_tps_str = f"{ort_tps:.1f}" if ort_tps > 0 else "N/A" + speedup_str = f"{speedup:.2f}x" if speedup > 0 else "N/A" + throughput_str = f"{throughput_ratio:.2f}x" if throughput_ratio > 0 else "N/A" + + print( + f"{config:<15} {torch_ttft:<10.3f} {torch_tps:<10.1f} {ort_ttft_str:<11} {ort_tps_str:<11} {speedup_str:<10} {throughput_str:<11} {max_diff:<10.6f}" + ) + + print("\nNotes:") + print("- Time: Token generation time in ms (lower is better)") + print("- tok/s: Tokens per second throughput (higher is better)") + print("- Time Gain: ORT speedup for latency (higher is better)") + print("- Throughput: ORT throughput improvement (higher is better)") + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index b7a9da8e1b658..8c2928670934a 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -494,6 +494,35 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, CApiTestWithProvider, ::testing::Values(0, 1, 2, 3, 4)); +TEST(CApiTest, TestInputPassThroughToOutput) { + const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx"); + Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_EQ(1U, inputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_EQ(1U, inputs_epdevices.size()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(7U, outputs_meminfos.size()); +} + +TEST(CApiTest, TestDanglingInput) { + // Here we test an issue with segments_ids that is an input not consumed by anything + // This kind of model is unlikely to be used in practice but we want to make sure it works + const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx"); + Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_EQ(2U, inputs_meminfos.size()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(2U, outputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_EQ(2U, inputs_epdevices.size()); + // One of the devices returning is null since the input is not consumed + // there is not a device for it. 
+ const bool null_present = std::any_of(inputs_epdevices.begin(), inputs_epdevices.end(), + [](const auto& device) { return device == nullptr; }); + ASSERT_TRUE(null_present); +} + #if !defined(DISABLE_SPARSE_TENSORS) TEST(CApiTest, SparseOutputModel) { std::vector dense_shape{3, 3}; @@ -505,7 +534,15 @@ TEST(CApiTest, SparseOutputModel) { std::vector ort_inputs; std::vector input_names; const char* const output_names[] = {"values"}; + // This model produces a sparse output from a constant sparse initializer Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_TRUE(inputs_meminfos.empty()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(1U, outputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_TRUE(inputs_epdevices.empty()); + auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, 1); ASSERT_EQ(ort_outputs.size(), 1U); diff --git a/onnxruntime/test/testdata/add_mul_add.onnx b/onnxruntime/test/testdata/add_mul_add.onnx new file mode 100644 index 0000000000000..0e2bc1bb9cff9 --- /dev/null +++ b/onnxruntime/test/testdata/add_mul_add.onnx @@ -0,0 +1,28 @@ + +:´ + +A +B +add_outputadd_0"Add +' + +add_output +B +mul_outputmul_0"Mul + + +mul_output +ACadd_1"Add +Main_graphZ +A +  + +Z +B +  + +b +C +  + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/add_mul_add.py b/onnxruntime/test/testdata/add_mul_add.py new file mode 100644 index 0000000000000..c22a176e065dd --- /dev/null +++ b/onnxruntime/test/testdata/add_mul_add.py @@ -0,0 +1,37 @@ +from onnx import TensorProto, checker, helper, save + +# (A + B) * B + A +graph_proto = helper.make_graph( + nodes=[ + helper.make_node( + "Add", + inputs=["A", "B"], + outputs=["add_output"], + name="add_0", + ), + helper.make_node( + "Mul", + inputs=["add_output", "B"], + outputs=["mul_output"], + name="mul_0", + ), + helper.make_node( + "Add", + inputs=["mul_output", "A"], + outputs=["C"], + name="add_1", + ), + ], + name="Main_graph", + inputs=[ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 2]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]), + ], + outputs=[ + helper.make_tensor_value_info("C", TensorProto.FLOAT, [3, 2]), + ], +) + +model = helper.make_model(graph_proto) +checker.check_model(model, True) +save(model, "add_mul_add.onnx") diff --git a/onnxruntime/test/testdata/input_propagated_to_output.onnx b/onnxruntime/test/testdata/input_propagated_to_output.onnx new file mode 100644 index 0000000000000..28d805ce87868 Binary files /dev/null and b/onnxruntime/test/testdata/input_propagated_to_output.onnx differ diff --git a/onnxruntime/test/testdata/input_propagated_to_output.py b/onnxruntime/test/testdata/input_propagated_to_output.py new file mode 100644 index 0000000000000..d548f40507d47 --- /dev/null +++ b/onnxruntime/test/testdata/input_propagated_to_output.py @@ -0,0 +1,113 @@ +""" +Run this script to recreate the original onnx model. 
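+The graph passes the raw "input" tensor through to the session outputs alongside the
+intermediate values X1..X6 (7 outputs total), which is what the TestInputPassThroughToOutput
+test exercises.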
+Example usage: +python input_propagated_to_output.py input_propagated_to_output.onnx +""" + +import sys + +import numpy as np +import onnx + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = onnx.helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = onnx.helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +W1 = np.array( + [ + [[[0.3258337378501892]], [[0.1461111307144165]], [[-0.4239698648452759]]], + [[[0.14769716560840607]], [[0.20565544068813324]], [[-0.5241780877113342]]], + [[[0.07987150549888611]], [[-0.17475983500480652]], [[0.005230882670730352]]], + ], + dtype=np.float32, +) + +B1 = np.array( + [-0.3170531392097473, -0.2701416313648224, -0.14249320328235626], + dtype=np.float32, +) + +W3 = np.array( + [ + [[[0.14025720953941345]], [[0.1433156430721283]], [[-0.1403128057718277]]], + [[[-0.07530076801776886]], [[0.11853527277708054]], [[-0.19437682628631592]]], + [[[0.5786639451980591]], [[-0.28565627336502075]], [[0.9048876166343689]]], + ], + dtype=np.float32, +) + +B3 = np.array( + [-0.13307525217533112, 0.5522456169128418, 0.6449958086013794], + dtype=np.float32, +) + +W5 = np.array( + [ + [[[-0.08959630876779556]], [[0.07607565075159073]], [[0.24446037411689758]]], + [[[-0.06293385475873947]], [[-0.41520264744758606]], [[-0.83400559425354]]], + [[[-0.031176576390862465]], [[-0.04187283664941788]], [[-0.439873069524765]]], + ], + dtype=np.float32, +) + +B5 = np.array( + [0.5949633717536926, -0.40198755264282227, -0.20182392001152039], + dtype=np.float32, +) + +model = onnx.helper.make_model( + opset_imports=[onnx.helper.make_operatorsetid("", 14)], + ir_version=7, + graph=make_graph( + name="input_propagated_to_output", + inputs=[ + onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, shape=[1, 3, 1, 3]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("X6", onnx.TensorProto.FLOAT, shape=[1, 3, 1, 3]), + onnx.helper.make_tensor_value_info("X1", onnx.TensorProto.FLOAT, shape=[1, 3, 1, 3]), + onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, shape=[1, 3, 1, 3]), + onnx.helper.make_tensor_value_info("X2", onnx.TensorProto.FLOAT, shape=[1, 3, 1, 3]), + onnx.helper.make_tensor_value_info("X4", onnx.TensorProto.FLOAT, shape=[1, 3, 1, 3]), + onnx.helper.make_tensor_value_info("X3", onnx.TensorProto.FLOAT, shape=[1, 3, 1, 3]), + onnx.helper.make_tensor_value_info("X5", onnx.TensorProto.FLOAT, shape=[1, 3, 1, 3]), + ], + initializer=[ + onnx.numpy_helper.from_array(W1, name="W1"), + onnx.numpy_helper.from_array(W3, name="W3"), + onnx.numpy_helper.from_array(W5, name="W5"), + onnx.numpy_helper.from_array(B1, name="B1"), + onnx.numpy_helper.from_array(B3, name="B3"), + onnx.numpy_helper.from_array(B5, name="B5"), + ], + nodes=[ + make_node("Relu", inputs=["input"], outputs=["X1"], name="Relu1"), + make_node("Conv", inputs=["X1", "W1", "B1"], outputs=["X2"], name="Conv1"), + make_node("Relu", inputs=["X2"], outputs=["X3"], name="Relu2"), + make_node("Conv", inputs=["X3", "W3", "B3"], outputs=["X4"], name="Conv2"), + make_node("Conv", inputs=["X1", "W5", "B5"], 
outputs=["X5"], name="Conv3"), + make_node("Add", inputs=["X4", "X5"], outputs=["X6"], name="Add"), + ], + ), +) + +if __name__ == "__main__" and len(sys.argv) == 2: + _, out_path = sys.argv + onnx.save(model, out_path) diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 43f6e480672ba..a04aafecbc81a 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -32,6 +32,31 @@ "^test_adagrad", "^test_adagrad_multiple", "^test_attention_3d.*", // wrong expected values in onnx==1.18.0, fixed in 1.19.0 + "^test_attention_4d_diff_heads_mask4d_padded_kv*", // pending onnx update + "^test_attention_3d_gqa*", // pending onnx update + "^test_attention_3d_gqa_causal", // pending onnx update + "^test_attention_3d_gqa_scaled", // pending onnx update + "^test_attention_3d_gqa_softcap", // pending onnx update + "^test_attention_3d_gqa_with_past_and_present", // pending onnx update + "^test_attention_4d_gqa*", // pending onnx update + "^test_attention_4d_gqa_causal", // pending onnx update + "^test_attention_4d_gqa_scaled", // pending onnx update + "^test_attention_4d_gqa_softcap", // pending onnx update + "^test_attention_4d_gqa_with_past_and_present", // pending onnx update + "^test_attention_*causal*", // pending onnx update + "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal*", // pending onnx update + "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal*", // pending onnx update + "^test_attention_4d_attn_mask_3d_causal_expanded*", // pending onnx update + "^test_attention_4d_fp16*", // precision issue: 1 / 192 mismatched elements + "^test_attention_4d_fp16_expanded*", // precision issue: 3 / 192 mismatched elements + "^test_l2normalization*", // LpNormalization(22) not implemented + "^test_l1normalization*", // LpNormalization(22) not implemented + "^test_lpnormalization*", // LpNormalization(22) not implemented + "^test_tensorscatter*", // TensorScatter(24) not implemented + "^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes + "^test_castlike_UINT4_to*", // ORT does not support ml_dtypes + "^test_castlike_INT4_to*", // ORT does not support ml_dtypes + "^test_cast_e8m0_*", // ORT does not support float8e8m0 "^test_batchnorm_epsilon_training_mode", "^test_batchnorm_example_training_mode", "^test_col2im_pads", // still one wrong value coming from the backtest example diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx new file mode 100644 index 0000000000000..a83c21030ad67 Binary files /dev/null and b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx differ diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py new file mode 100644 index 0000000000000..0cd4741bbbf36 --- /dev/null +++ b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py @@ -0,0 +1,136 @@ +""" +Run this script to recreate the original onnx model. 
+Example usage: +python test_dangling_input_segment_ids.py test_dangling_input_segment_ids.onnx +""" + +import sys + +import numpy as np +import onnx + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = onnx.helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = onnx.helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +WORD_EMBED = np.array( + [ + [0.31524479389190674, 0.8928887248039246, 0.5778571963310242, 0.18401020765304565], + [0.7879292368888855, 0.6120311617851257, 0.05390927195549011, 0.4201936721801758], + [0.6790688633918762, 0.9186017513275146, 0.0004020248888991773, 0.976759135723114], + [0.3765803277492523, 0.973783552646637, 0.6047161221504211, 0.8288457989692688], + [0.5747115015983582, 0.6280761957168579, 0.2855762839317322, 0.5868333578109741], + [0.750021755695343, 0.8583138585090637, 0.7550821900367737, 0.698057234287262], + [0.8644794225692749, 0.3226810097694397, 0.6707887649536133, 0.4508739411830902], + [0.38210275769233704, 0.4108113646507263, 0.401479572057724, 0.31738394498825073], + [0.6219193935394287, 0.4302472770214081, 0.9738020896911621, 0.6778008937835693], + [0.1985698938369751, 0.42670100927352905, 0.3433462381362915, 0.7976388335227966], + [0.8799982666969299, 0.9038419723510742, 0.6627197861671448, 0.2702082693576813], + [0.25236669182777405, 0.8548979163169861, 0.5277146697044373, 0.8021610975265503], + [0.57248854637146, 0.7331425547599792, 0.5190116167068481, 0.7708839178085327], + [0.5688579678535461, 0.4657098650932312, 0.3426889181137085, 0.06820935010910034], + [0.3779241740703583, 0.07962607592344284, 0.9828171133995056, 0.18161284923553467], + [0.8118587136268616, 0.8749616742134094, 0.6884132623672485, 0.5694944262504578], + [0.16097143292427063, 0.46688002347946167, 0.34517204761505127, 0.22503995895385742], + [0.5925118923187256, 0.31226983666419983, 0.9163055419921875, 0.9096355438232422], + [0.257118284702301, 0.11089129745960236, 0.19296273589134216, 0.4995841681957245], + [0.7285856604576111, 0.20819443464279175, 0.2480335533618927, 0.8516718745231628], + [0.4158487319946289, 0.6166850924491882, 0.23366613686084747, 0.10196726024150848], + [0.5158570408821106, 0.47714099287986755, 0.15267165005207062, 0.6218062043190002], + [0.5440101027488708, 0.654137372970581, 0.1445455402135849, 0.7515278458595276], + [0.22204914689064026, 0.5193518400192261, 0.7852960228919983, 0.022330427542328835], + [0.32436245679855347, 0.8729223608970642, 0.8447096347808838, 0.5384405851364136], + [0.8666082620620728, 0.9498059749603271, 0.8264070153236389, 0.8541154265403748], + [0.09874340146780014, 0.651304304599762, 0.703516960144043, 0.6102408170700073], + [0.7996152639389038, 0.034571219235658646, 0.7702387571334839, 0.7317286133766174], + [0.25969839096069336, 0.25706928968429565, 0.6323032975196838, 0.3452974557876587], + [0.796588659286499, 0.4461462199687958, 0.7827494144439697, 0.9904717803001404], + [0.30024832487106323, 0.143005833029747, 0.9013084173202515, 0.5415593981742859], + [0.9747403860092163, 0.6366044282913208, 0.9939129948616028, 0.5460708141326904], + 
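+        # 32 x 4 table: word embeddings for a vocabulary of 32 tokens, hidden size 4.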
], + dtype=np.float32, +) + +POS_EMBED = np.array( + [ + [0.5264259576797485, 0.13542790710926056, 0.3557051718235016, 0.026218567043542862], + [0.16039517521858215, 0.7456371784210205, 0.030399689450860023, 0.36654308438301086], + [0.8623462319374084, 0.6926777362823486, 0.6909421682357788, 0.18863679468631744], + [0.4419042766094208, 0.5815774202346802, 0.9897516965866089, 0.20390622317790985], + [0.24773290753364563, 0.2621730864048004, 0.7501724362373352, 0.4569753408432007], + [0.056929439306259155, 0.508516252040863, 0.21196016669273376, 0.7986042499542236], + [0.29733139276504517, 0.027606012299656868, 0.5934324264526367, 0.8438404202461243], + [0.3810161352157593, 0.7498583197593689, 0.5111414790153503, 0.5409517884254456], + [0.9594343304634094, 0.803960919380188, 0.032323066145181656, 0.7093872427940369], + [0.46500149369239807, 0.9475489258766174, 0.22143273055553436, 0.26707202196121216], + [0.08147396147251129, 0.42861881852149963, 0.10901876538991928, 0.6337867379188538], + [0.8029632568359375, 0.6968004703521729, 0.7662113904953003, 0.34245410561561584], + [0.845851480960846, 0.4287687838077545, 0.824009895324707, 0.6264961361885071], + [0.14342305064201355, 0.07838690280914307, 0.018332643434405327, 0.0667250007390976], + [0.458583801984787, 0.11334192007780075, 0.0277833491563797, 0.7548614740371704], + [0.394850492477417, 0.7469384670257568, 0.45240482687950134, 0.4500867426395416], + ], + dtype=np.float32, +) + +model = onnx.helper.make_model( + opset_imports=[onnx.helper.make_operatorsetid("", 14), onnx.helper.make_operatorsetid("com.microsoft", 1)], + ir_version=7, + graph=make_graph( + name="embed_layernorm_graph", + inputs=[ + onnx.helper.make_tensor_value_info("input_ids", onnx.TensorProto.INT32, shape=[1, 4]), + onnx.helper.make_tensor_value_info("segment_ids", onnx.TensorProto.INT32, shape=[1, 4]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("layernorm_out", onnx.TensorProto.FLOAT, shape=[1, 4, 4]), + onnx.helper.make_tensor_value_info("mask_index_out", onnx.TensorProto.INT32, shape=[1]), + ], + initializer=[ + onnx.numpy_helper.from_array(WORD_EMBED, name="word_embed"), + onnx.numpy_helper.from_array(POS_EMBED, name="pos_embed"), + onnx.numpy_helper.from_array( + np.array( + [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], + dtype="float32", + ), + name="gamma", + ), + onnx.numpy_helper.from_array( + np.array( + [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32" + ), + name="beta", + ), + ], + nodes=[ + make_node( + "EmbedLayerNormalization", + inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"], + outputs=["layernorm_out", "mask_index_out"], + domain="com.microsoft", + ) + ], + ), +) + +if __name__ == "__main__" and len(sys.argv) == 2: + _, out_path = sys.argv + onnx.save(model, out_path) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 22da226bbb7d9..b625e3ca67db5 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -891,7 +891,7 @@ def generate_build_tree( # * Leave disabled if "no_kleidiai" argument was specified. # * Enable if the target is Android and args.android_abi contains arm64* # * Enable for a Windows cross compile build if compile target is an Arm one. - # * Finally enable if platform.machine contains "arm64". This should cover the following cases: + # * Finally enable if platform.machine contains "arm64" and not a WebAssembly build. 
This should cover the following cases: # * Linux on Arm # * MacOs (case must be ignored) # * TODO Delegate responsibility for Onnxruntime_USE_KLEIDIAI = ON to CMake logic @@ -899,7 +899,7 @@ def generate_build_tree( if ( (args.android and "arm64" in args.android_abi.lower()) or (is_windows() and (args.arm64 or args.arm64ec or args.arm) and platform.architecture()[0] != "AMD64") - or ("arm64" in platform.machine().lower()) + or ("arm64" in platform.machine().lower() and not args.build_wasm) ): cmake_args += ["-Donnxruntime_USE_KLEIDIAI=ON"]
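
A note on the block-wise scheme exercised by the qMoE tests above: quant_dequant_blockwise presumably stores one symmetric scale per block_size consecutive weights along a row, and the tests fall back to the row-wise quant_dequant when block_size is 0. The sketch below is illustrative only: the function name, the return order (scale, quantized weight, dequantized weight), and the symmetric int4/int8 ranges are assumptions modeled on how the tests call the helper, not the actual implementation.

import numpy as np

def quant_dequant_blockwise_sketch(weight, block_size, is_4_bit):
    # Symmetric block-wise quantization: one scale per block_size values per row.
    qmax = 7 if is_4_bit else 127  # int4 range [-8, 7], int8 range [-128, 127]
    rows, cols = weight.shape
    assert cols % block_size == 0, "row width must be divisible by block_size"
    blocks = weight.reshape(rows, cols // block_size, block_size)
    # Pick each block's scale so its largest magnitude maps to qmax.
    scales = np.abs(blocks).max(axis=-1, keepdims=True) / qmax
    scales = np.where(scales == 0.0, 1.0, scales)  # keep all-zero blocks finite
    qweight = np.clip(np.round(blocks / scales), -qmax - 1, qmax).astype(np.int8)
    dequant = (qweight.astype(np.float32) * scales).reshape(rows, cols)
    return scales.squeeze(-1), qweight.reshape(rows, cols), dequant

# One scale per 32-element block: an 8 x 64 weight yields an 8 x 2 scale tensor.
w = np.random.randn(8, 64).astype(np.float32)
scales, qw, w_qdq = quant_dequant_blockwise_sketch(w, block_size=32, is_4_bit=True)
print(scales.shape, qw.dtype, float(np.abs(w - w_qdq).max()))

Row-wise quantization is then the degenerate case of one block spanning the whole row (a single scale per output row), which is consistent with the tests treating block_size == 0 as the row-wise path.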