diff --git a/VERSION_NUMBER b/VERSION_NUMBER index a6c2798a482eb..49e0a31d4964d 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.23.0 +1.23.1 diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 6aad71e40b2a8..b23365e99c2d7 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1800,6 +1800,7 @@ endif() if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD) + # example_plugin_ep file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h" "${TEST_SRC_DIR}/autoep/library/*.cc") onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src}) @@ -1822,6 +1823,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG}) + set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest") + source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src}) + # test library file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h" "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc") diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 660c63d056335..b8385a3352278 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -253,6 +253,8 @@ Do not modify directly.* |||[9, 12]|**T** = tensor(float)| |||[1, 8]|**T** = tensor(float)| |MelWeightMatrix|*in* num_mel_bins:**T1**
*in* dft_length:**T1**
*in* sample_rate:**T1**
*in* lower_edge_hertz:**T2**
*in* upper_edge_hertz:**T2**
*out* output:**T3**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(float)
**T3** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[8, 11]|**T** = tensor(double), tensor(float)| diff --git a/docs/python/README.rst b/docs/python/README.rst index fdef200c1d0de..c23c194ed8132 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime >; /** \brief Wrapper around ::OrtSyncStream * */ -struct SyncStream : detail::Base { - explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used - explicit SyncStream(OrtSyncStream* p) : Base{p} {} ///< Take ownership of a pointer created by C API - void* GetHandle() const; ///< Wraps SyncStream_GetHandle + +namespace detail { +template +struct SyncStreamImpl : Base { + using B = Base; + using B::B; + // For some reason this is not a const method on the stream + void* GetHandle(); ///< Wraps SyncStream_GetHandle }; +} // namespace detail + +struct SyncStream : detail::SyncStreamImpl { + ///< Create an empty SyncStream object, must be assigned a valid one to be used + explicit SyncStream(std::nullptr_t) {} + ///< Take ownership of a pointer created by C API + explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl{p} {} +}; + +using UnownedSyncStream = detail::SyncStreamImpl>; namespace detail { template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 59979189eed0f..cb6448ad12a81 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -669,9 +669,12 @@ inline void KeyValuePairs::Remove(const char* key) { GetApi().RemoveKeyValuePair(this->p_, key); } -inline void* SyncStream::GetHandle() const { +namespace detail { +template +inline void* SyncStreamImpl::GetHandle() { return GetApi().SyncStream_GetHandle(this->p_); } +} // namespace detail namespace detail { template @@ -1582,11 +1585,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForInputs( auto num_inputs = GetInputCount(); std::vector mem_infos; - mem_infos.resize(num_inputs); + if (num_inputs > 0) { + mem_infos.resize(num_inputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, - reinterpret_cast(mem_infos.data()), - num_inputs)); + ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_inputs)); + } return mem_infos; } @@ -1598,11 +1603,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs auto num_outputs = GetOutputCount(); std::vector mem_infos; - mem_infos.resize(num_outputs); + if (num_outputs > 0) { + mem_infos.resize(num_outputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, - reinterpret_cast(mem_infos.data()), - num_outputs)); + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_outputs)); + } return mem_infos; } @@ -1631,12 +1638,12 @@ template inline std::vector ConstSessionImpl::GetEpDeviceForInputs() const { auto num_inputs = GetInputCount(); std::vector input_devices; - input_devices.resize(num_inputs); - - ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, - reinterpret_cast(input_devices.data()), - num_inputs)); - + if (num_inputs > 0) { + input_devices.resize(num_inputs); + ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, + reinterpret_cast(input_devices.data()), + num_inputs)); + } return input_devices; } diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 994eb6f4300c1..f00f4ec8dee50 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.23.0'; +export const version = '1.23.1'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 706f8b46a3ad4..9ef468a229788 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/common/package.json b/js/common/package.json index a0eff9095e6d7..200aff42f8fca 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 994eb6f4300c1..f00f4ec8dee50 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.23.0'; +export const version = '1.23.1'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index bd7e6cc1966c7..0a65eab39df70 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.23.0", + "version": "1.23.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.23.0", + "version": "1.23.1", "hasInstallScript": true, "license": "MIT", "os": [ @@ -30,7 +30,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/node/package.json b/js/node/package.json index 5520a48aa124a..1f29f2354b0d7 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -11,7 +11,7 @@ 6 ] }, - "version": "1.23.0", + "version": "1.23.1", "dependencies": { "adm-zip": "^0.5.16", "global-agent": "^3.0.0", diff --git a/js/node/script/install-metadata-versions.js b/js/node/script/install-metadata-versions.js index 3147f90904e7a..23df0b7ac96ed 100644 --- a/js/node/script/install-metadata-versions.js +++ b/js/node/script/install-metadata-versions.js @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -module.exports = { nuget: [{ feed: 'nuget', version: '1.23.0' }] }; +module.exports = { nuget: [{ feed: 'nuget', version: '1.23.1' }] }; diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 994eb6f4300c1..f00f4ec8dee50 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.23.0'; +export const version = '1.23.1'; diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index ec2147b2cc4ba..f681b9166da98 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-react-native", - "version": "1.23.0", + "version": "1.23.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "onnxruntime-react-native", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "dependencies": { "buffer": "^6.0.3", @@ -31,7 +31,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/react_native/package.json b/js/react_native/package.json index 7a5ee35bdb25a..a88f5cf267aed 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -37,7 +37,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.23.0", + "version": "1.23.1", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 994eb6f4300c1..f00f4ec8dee50 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.23.0'; +export const version = '1.23.1'; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index eabb198e97177..74776abb25bd5 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.23.0", + "version": "1.23.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "dependencies": { "flatbuffers": "^25.1.24", @@ -50,7 +50,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.23.0", + "version": "1.23.1", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/web/package.json b/js/web/package.json index 425aa88035424..db20202b4f24e 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -7,7 +7,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.23.0", + "version": "1.23.1", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^25.1.24", diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 550502cf3bc48..ac25159802092 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -8,7 +8,7 @@ or the `Github project `_. """ -__version__ = "1.23.0" +__version__ = "1.23.1" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). @@ -31,14 +31,17 @@ OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 OrtCompileApiFlags, # noqa: F401 + OrtDeviceMemoryType, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 OrtExternalInitializerInfo, # noqa: F401 OrtHardwareDevice, # noqa: F401 OrtHardwareDeviceType, # noqa: F401 OrtMemoryInfo, # noqa: F401 + OrtMemoryInfoDeviceType, # noqa: F401 OrtMemType, # noqa: F401 OrtSparseFormat, # noqa: F401 + OrtSyncStream, # noqa: F401 RunOptions, # noqa: F401 SessionIOBinding, # noqa: F401 SessionOptions, # noqa: F401 @@ -78,6 +81,7 @@ OrtDevice, # noqa: F401 OrtValue, # noqa: F401 SparseTensor, # noqa: F401 + copy_tensors, # noqa: F401 ) # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 98cc2158eb0d0..01ba492eb166e 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -226,13 +226,22 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne for (auto& node : graph_.Nodes()) { const KernelCreateInfo* kci = nullptr; auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); - if (!status.IsOK() && saving_ort_format) { - // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. - // in that case we assigned the node to that EP but do not compile it into a fused node. - // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it. - // we now revert to the CPU EP kernel as a fallback. - // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. - // if that's not possible for some reason we can fallback to the CPU EP implementation. + + // There are two cases where we allow fallback to CPU EP kernels: + // + // 1. if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. + // in that case we assigned the node to that EP but do not compile it into a fused node. + // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it. + // we now revert to the CPU EP kernel as a fallback. + // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. + // if that's not possible for some reason we can fallback to the CPU EP implementation. + // + // 2. If the node is a memcpy node. + // EPs may provide their own memcpy kernels. The CPU EP provides a generic version to fall back to if the EP does + // not provide one. + const bool allow_cpu_ep_kernel_fallback = saving_ort_format || utils::IsMemcpyNode(node); + + if (!status.IsOK() && allow_cpu_ep_kernel_fallback) { node.SetExecutionProviderType(kCpuExecutionProvider); status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 2c0a51f0bfdbc..ca64c7c7cae89 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -46,22 +46,13 @@ void DestroyStrings(void* p_data, int64_t elements) { ptr[i].~string(); } -bool ProviderIsCpuBased(const std::string& provider_type) { - return provider_type == onnxruntime::kCpuExecutionProvider || - provider_type == onnxruntime::kDnnlExecutionProvider || - provider_type == onnxruntime::kVitisAIExecutionProvider || - provider_type == onnxruntime::kOpenVINOExecutionProvider || - provider_type == onnxruntime::kNnapiExecutionProvider || - provider_type == onnxruntime::kVSINPUExecutionProvider || - provider_type == onnxruntime::kAclExecutionProvider || - provider_type == onnxruntime::kArmNNExecutionProvider || - provider_type == onnxruntime::kRknpuExecutionProvider || - provider_type == onnxruntime::kCoreMLExecutionProvider || - provider_type == onnxruntime::kSnpeExecutionProvider || - provider_type == onnxruntime::kQnnExecutionProvider || - provider_type == onnxruntime::kXnnpackExecutionProvider || - provider_type == onnxruntime::kAzureExecutionProvider || - provider_type == onnxruntime::utils::kInternalTestingExecutionProvider; +bool ProviderIsCpuBased(const IExecutionProvider& provider) { + return provider.GetDevice().Type() == OrtDevice::CPU; +} + +bool IsMemcpyNode(const Node& node) { + return node.Domain() == kOnnxDomain && + (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost"); } static common::Status AllocateHelper(const AllocatorPtr& allocator, @@ -210,7 +201,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_providers) { for (const auto& execution_provider : execution_providers) { - if (!ProviderIsCpuBased(execution_provider->Type())) { + if (!ProviderIsCpuBased(*execution_provider)) { return false; } } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 6b5c404e26b7f..796a17b0406da 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -52,12 +52,10 @@ void DestroyStrings(void* p_data, int64_t elements); const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info); -// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want -// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal. -constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider"; - // return true if the execution provider is CPU based (meaning no copies to device are required) -bool ProviderIsCpuBased(const std::string& provider_type); +bool ProviderIsCpuBased(const IExecutionProvider& provider); + +bool IsMemcpyNode(const Node& node); common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name, const OrtValue& orig_mlvalue, OrtValue& new_mlvalue); diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index cc7682b2b418d..9d49c16391f78 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -1,7 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "transformer_memcpy.h" +#include "core/optimizer/transformer_memcpy.h" + #include "core/common/logging/logging.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" @@ -12,18 +13,39 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { +static ProviderTypeToProviderMap GetProvidersByType( + const InlinedVector>& providers) { + ProviderTypeToProviderMap providers_by_type{}; + for (const auto provider : providers) { + providers_by_type.emplace(provider->Type(), provider); + } + return providers_by_type; +} + +MemcpyTransformer::MemcpyTransformer(InlinedVector> providers, + const KernelRegistryManager& registry_manager) + : GraphTransformer("MemcpyTransformer"), + providers_(std::move(providers)), + providers_by_type_(GetProvidersByType(providers_)), + registry_manager_(std::cref(registry_manager)) { +} + // implements MemCpy node insertion in graph transform // note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer class TransformerMemcpyImpl { public: - TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) - : graph_(graph), provider_(provider) {} + TransformerMemcpyImpl(onnxruntime::Graph& graph, const IExecutionProvider& provider, + const ProviderTypeToProviderMap& providers_by_type) + : graph_(graph), provider_(provider), providers_by_type_(providers_by_type) { + } bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); private: + bool IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const; + void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed, @@ -31,7 +53,9 @@ class TransformerMemcpyImpl { void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries, const logging::Logger& logger); - void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); + void AddCopyNode(onnxruntime::NodeArg* arg, + bool is_input, + const logging::Logger& logger); bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed, const logging::Logger& logger); @@ -55,7 +79,8 @@ class TransformerMemcpyImpl { std::map> provider_output_nodes_; onnxruntime::Graph& graph_; - std::string provider_; + const IExecutionProvider& provider_; + const ProviderTypeToProviderMap& providers_by_type_; }; /** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer. @@ -73,17 +98,18 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st // very simple GraphTransformer that uses TransformerMemcpyImpl for each graph // and mainly provides the subgraph recursion functionality -common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { - for (auto& provider : provider_types_) { - if (!utils::ProviderIsCpuBased(provider)) { - TransformerMemcpyImpl copy_impl(graph, provider); +Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + for (const auto provider : providers_) { + const auto& provider_type = provider->Type(); + if (!utils::ProviderIsCpuBased(*provider)) { + TransformerMemcpyImpl copy_impl(graph, *provider, providers_by_type_); int copy_node_counter = 0; auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter); - if (copy_node_counter > 0 && provider == kCudaExecutionProvider) { + if (copy_node_counter > 0 && provider_type == kCudaExecutionProvider) { LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name() - << " for " << provider + << " for " << provider_type << ". It might have negative impact on performance (including unable to run CUDA graph). " << "Set session_options.log_severity_level=1 to see the detail logs before this message."; } @@ -213,15 +239,42 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi return modified; } +static const IExecutionProvider* FindProviderByType(ProviderTypeToProviderMap providers_by_type, + std::string_view provider_type) { + const auto it = providers_by_type.find(provider_type); + if (it != providers_by_type.end()) { + return &*it->second; + } + return nullptr; +} + +bool TransformerMemcpyImpl::IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const { + const auto& node_provider_type = node.GetExecutionProviderType(); + const auto* node_provider = FindProviderByType(providers_by_type_, node_provider_type); + ORT_ENFORCE(node_provider != nullptr, "Unable to get provider associated with provider type ", node_provider_type); + + // Same provider? + if (node_provider->Type() == provider_.Type()) { + return true; + } + + const auto& node_provider_device = node_provider->GetDevice(); + const auto& provider_device = provider_.GetDevice(); + + // Same provider device type and vendor? + if (node_provider_device.Type() == provider_device.Type() && + node_provider_device.Vendor() == provider_device.Vendor()) { + return true; + } + + return false; +} + void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed, const logging::Logger& logger) { - auto node_provider_type = node.GetExecutionProviderType(); - if ((node_provider_type == provider_) || - (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || - (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) || - (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) { + if (IsNodeCompatibleWithProvider(node)) { provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; @@ -268,9 +321,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, else provider_output_defs_.insert(arg); } - } else if (node_provider_type != kCudaExecutionProvider && node_provider_type != kTensorrtExecutionProvider && - node_provider_type != kCudaExecutionProvider && node_provider_type != kNvTensorRTRTXExecutionProvider && - node_provider_type != kRocmExecutionProvider && node_provider_type != kMIGraphXExecutionProvider) { + } else { for (const auto* arg : node.InputDefs()) { if (arg->Exists()) non_provider_input_defs_.insert(arg); @@ -297,7 +348,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries, const logging::Logger& logger) { for (auto& it : graph_.Nodes()) { - if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue; + if (utils::IsMemcpyNode(it)) continue; auto input_it = std::find(it.MutableInputDefs().begin(), it.MutableInputDefs().end(), const_cast(arg)); auto output_it = @@ -309,10 +360,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, if (arg_input_index == -1 && arg_output_index == -1) continue; auto node_provider_type = it.GetExecutionProviderType(); - if ((node_provider_type == provider_) || - (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || - (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) || - (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) { + if (IsNodeCompatibleWithProvider(it)) { const KernelCreateInfo* kci = nullptr; ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci)); if (arg_input_index != -1) { @@ -325,9 +373,11 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, } } -void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) { +void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, + bool is_input, + const logging::Logger& logger) { // create unique name for new def - std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_); + std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_.Type()); auto* new_arg = &graph_.GetOrCreateNodeArg(new_def_name, arg->TypeAsProto()); auto* src_arg = is_input ? arg : new_arg; @@ -338,12 +388,14 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost"; LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name() - << " for " << provider_; + << " for " << provider_.Type(); auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory", std::vector{src_arg}, std::vector{dst_arg}); - new_node.SetExecutionProviderType(provider_); + + new_node.SetExecutionProviderType(provider_.Type()); + std::map map = {{arg, new_arg}}; auto it = provider_input_nodes_.find(arg); if (it != provider_input_nodes_.end()) { diff --git a/onnxruntime/core/optimizer/transformer_memcpy.h b/onnxruntime/core/optimizer/transformer_memcpy.h index a2403d269f89b..f6b60a83fcf32 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.h +++ b/onnxruntime/core/optimizer/transformer_memcpy.h @@ -5,13 +5,19 @@ #include +#include "gsl/gsl" + #include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/framework/execution_provider.h" #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry_manager.h" #include "core/optimizer/graph_transformer.h" namespace onnxruntime { +using ProviderTypeToProviderMap = InlinedHashMap>; + /** @Class MemcpyTransformer @@ -19,13 +25,14 @@ Transformer that inserts nodes to copy memory between devices when needed. */ class MemcpyTransformer : public GraphTransformer { public: - MemcpyTransformer(const std::vector& provider_types, const KernelRegistryManager& registry_manager) - : GraphTransformer("MemcpyTransformer"), provider_types_(provider_types), registry_manager_(std::cref(registry_manager)) {} + MemcpyTransformer(InlinedVector> providers, + const KernelRegistryManager& registry_manager); private: - common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - const std::vector provider_types_; + const InlinedVector> providers_; + const ProviderTypeToProviderMap providers_by_type_; std::reference_wrapper registry_manager_; }; diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 5eac0523d953a..1030e368a5fd6 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -4,6 +4,7 @@ #include "core/providers/cpu/cpu_execution_provider.h" #include "core/framework/allocator_utils.h" +#include "core/framework/memcpy.h" #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/int4.h" @@ -27,6 +28,33 @@ struct KernelRegistryAndStatus { } // namespace namespace onnxruntime { + +// The MemcpyFromHost and MemcpyToHost kernels registered for the CPU EP are generic memcpy kernels. +// Other EPs may provide their own memcpy kernels. +// For a memcpy between host (CPU) and device of some other EP: +// - If the EP provides the corresponding memcpy kernel, it will be used. +// - Otherwise, one of these generic memcpy kernels will be used. + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()), + Memcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPUOutput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()), + Memcpy); + CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kCpuExecutionProvider}, info_{info} {} @@ -39,6 +67,8 @@ std::vector CPUExecutionProvider::CreatePreferredAllocators() { } // Forward declarations of op kernels +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyToHost); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10, Clip); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, Elu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, HardSigmoid); @@ -1379,6 +1409,8 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index 4fc3cdf961d17..9d3736173dae2 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -47,14 +47,14 @@ void make_copy(MLFloat16* mask_data, const MLFloat16* mask template <> void make_copy(float* mask_data, const bool* mask_index, size_t size) { for (size_t i = 0; i < size; ++i) { - mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits::lowest(); + mask_data[i] = mask_index[i] ? 0.0f : negative_infinity(); } } template <> void make_copy(MLFloat16* mask_data, const bool* mask_index, size_t size) { for (size_t i = 0; i < size; ++i) { - mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits::lowest(); + mask_data[i] = mask_index[i] ? MLFloat16(0.f) : negative_infinity(); } } @@ -236,7 +236,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, mask_data = static_cast(allocated_ptr); for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { - mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits::lowest(); + mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity(); } } delete_mask_data = true; @@ -262,7 +262,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, for (int i = 0; i < n_iter; ++i) { for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { - mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits::lowest(); + mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity(); } } } @@ -317,7 +317,8 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, } // handling GQA - std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads; + std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads; + std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki; const T* k = K + k_input_chunk_length * ki; if (nullptr != present_key) { @@ -347,7 +348,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, alpha, Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, parameters.head_size * parameters.q_num_heads, // lda - transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k, + transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k, transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb beta, output, @@ -555,7 +556,8 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu // handling GQA std::ptrdiff_t batch_i = i / num_heads; std::ptrdiff_t head_i = i % num_heads; - std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads; + std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads; + std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi; const T* v = V + v_input_chunk_length * vi; if (nullptr != present_value) { @@ -579,15 +581,15 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu // V is transposed but not QK. We use GemmEx with a different value for ldb. math::GemmEx(CblasNoTrans, CblasNoTrans, - sequence_length, // M - v_head_size, // N - total_sequence_length, // K - 1.f, // alpha - attention_probs + attention_probs_offset, // QK - total_sequence_length, // lda - transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V - transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb - 0.f, // beta + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + 1.f, // alpha + attention_probs + attention_probs_offset, // QK + total_sequence_length, // lda + transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V + transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb + 0.f, // beta output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), v_head_size * num_heads, // ldc nullptr); diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h index 78889e48afb29..4fad6914f933d 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.h +++ b/onnxruntime/core/providers/cpu/llm/attention.h @@ -9,6 +9,16 @@ namespace onnxruntime { +template +inline T negative_infinity() { + return -std::numeric_limits::infinity(); +} + +template <> +inline MLFloat16 negative_infinity() { + return MLFloat16(-std::numeric_limits::infinity()); +} + template class AttentionBase : public OpKernel { public: diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 0ee18cc6799fc..62210d65848d1 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1466,12 +1466,15 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra fused_inputs.erase(it); erased.insert(output); } - // Only when output is neither in input list nor erased list, add the output to output list + // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list else if (erased.find(output) == erased.end()) { if (graph_output_names.find(output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } - fused_outputs[output] = output_order++; + + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + fused_outputs[output] = output_order++; + } } } } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index b60f64db1734d..508d932459bf9 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2114,12 +2114,15 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph fused_inputs.erase(it); erased.insert(output); } - // Only when output is neither in input list nor erased list, add the output to output list + // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list else if (erased.find(output) == erased.end()) { if (graph_output_names.find(output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } - fused_outputs[output] = output_order++; + + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + fused_outputs[output] = output_order++; + } } } } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c0900c5ad28a0..112fd84c5ed45 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1517,12 +1517,12 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // Insert copy node/s. { - std::vector provider_types; + InlinedVector> providers; for (auto& provider_ptr : execution_providers_) { - provider_types.push_back(provider_ptr->Type()); + providers.push_back(provider_ptr.get()); } - MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager_}; + MemcpyTransformer copy_transformer{std::move(providers), kernel_registry_manager_}; ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(copy_transformer, *session_logger_, graph)); } @@ -3383,17 +3383,58 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType for (const auto* def : def_list) { InlinedVector node_info_vec; + Status status; if (type == SessionInputOutputType::kOutput) { - ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + status = session_state_->GetOutputNodeInfo(def->Name(), node_info_vec); } else { - ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); + status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); } - // all entries are for the same OrtDevice so use the first one. - // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice - // from the session state and use its OrtMemoryInfo. - auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); - memory_info.push_back(&allocator->Info()); + if (!status.IsOK()) { + if (type == SessionInputOutputType::kInput) { + return status; + } + + // Check first if this output is produced by an input that directly + // propagates to output with the same name. + status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); + if (status.IsOK()) { + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. + auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } else { + // Check if this output is produced by a constant initializer + // Pick the MemoryInfo from the initializer's OrtValue + const auto& ort_value_map = session_state_->GetOrtValueNameIdxMap(); + + OrtValueIndex ort_value_index; + status = ort_value_map.GetIdx(def->Name(), ort_value_index); + if (!status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to find node output or a constant initializer producing output: ", + def->Name(), "."); + } + + const auto& idx_to_ort_value = session_state_->GetInitializedTensors(); + auto it = idx_to_ort_value.find(ort_value_index); + if (it == idx_to_ort_value.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to find node output or a constant initializer producing output: ", + def->Name(), "."); + } + const auto& tensor = it->second.Get(); + auto allocator = session_state_->GetAllocator(tensor.Location()); + memory_info.push_back(&allocator->Info()); + } + } else { + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. + auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } } return Status::OK(); @@ -3422,15 +3463,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector node_info_vec; ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); - - // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map - // instead of doing a linear search each time. - const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType(); - auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { - return entry->ep_name == ep_name; - }); - - ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + assert(!node_info_vec.empty()); + // If we have an input that is not consumed by any node, + // including nodes in subgraphs, then we return nullptr. + const auto* p_node = node_info_vec.front().p_node; + if (p_node != nullptr) { + const auto ep_name = p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + } else { + ep_devices.push_back(nullptr); + } } return Status::OK(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 21d09df5cc4db..edc0cb6d2bd0f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4257,7 +4257,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.23.0", +static_assert(std::string_view(ORT_VERSION) == "1.23.1", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it // 2. If there were any APIs added to ort_api_1_to_23 above: diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 35abad5760c32..4c3313046457c 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -199,6 +199,18 @@ def get_modelmeta(self) -> onnxruntime.ModelMetadata: "Return the metadata. See :class:`onnxruntime.ModelMetadata`." return self._model_meta + def get_input_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]: + "Return the memory info for the inputs." + return self._input_meminfos + + def get_output_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]: + "Return the memory info for the outputs." + return self._output_meminfos + + def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]: + "Return the execution providers for the inputs." + return self._input_epdevices + def get_providers(self) -> Sequence[str]: "Return list of registered execution providers." return self._providers @@ -576,6 +588,9 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi self._inputs_meta = self._sess.inputs_meta self._outputs_meta = self._sess.outputs_meta self._overridable_initializers = self._sess.overridable_initializers + self._input_meminfos = self._sess.input_meminfos + self._output_meminfos = self._sess.output_meminfos + self._input_epdevices = self._sess.input_epdevices self._model_meta = self._sess.model_meta self._providers = self._sess.get_providers() self._provider_options = self._sess.get_provider_options() @@ -589,6 +604,9 @@ def _reset_session(self, providers, provider_options) -> None: self._inputs_meta = None self._outputs_meta = None self._overridable_initializers = None + self._input_meminfos = None + self._output_meminfos = None + self._input_epdevices = None self._model_meta = None self._providers = None self._provider_options = None @@ -1134,6 +1152,15 @@ def update_inplace(self, np_arr) -> None: self._ortvalue.update_inplace(np_arr) +def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None: + """ + Copy tensor data from source OrtValue sequence to destination OrtValue sequence. + """ + c_sources = [s._get_c_value() for s in src] + c_dsts = [d._get_c_value() for d in dst] + C.copy_tensors(c_sources, c_dsts, stream) + + class OrtDevice: """ A data structure that exposes the underlying C++ OrtDevice @@ -1146,6 +1173,7 @@ def __init__(self, c_ort_device): if isinstance(c_ort_device, C.OrtDevice): self._ort_device = c_ort_device else: + # An end user won't hit this error raise ValueError( "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`" ) @@ -1188,6 +1216,9 @@ def device_type(self): def device_vendor_id(self): return self._ort_device.vendor_id() + def device_mem_type(self): + return self._ort_device.mem_type() + class SparseTensor: """ diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 1fe7ab0884f9c..d74663ddb63d7 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -333,7 +333,7 @@ void addOrtValueMethods(pybind11::module& m) { }) #endif // Get a pointer to Tensor data - .def("data_ptr", [](OrtValue* ml_value) -> int64_t { + .def("data_ptr", [](OrtValue* ml_value) -> uintptr_t { // TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported"); @@ -344,7 +344,7 @@ void addOrtValueMethods(pybind11::module& m) { } // Should cover x86 and x64 platforms - return reinterpret_cast(tensor->MutableDataRaw()); + return reinterpret_cast(tensor->MutableDataRaw()); }) .def("device_name", [](const OrtValue* ort_value) -> std::string { if (ort_value->IsTensor()) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index e370518b1fffb..c17acc9ffff3a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -22,6 +22,7 @@ #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" #include "core/framework/error_code_helper.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/provider_options_utils.h" #include "core/framework/random_seed.h" #include "core/framework/sparse_tensor.h" @@ -1584,6 +1585,18 @@ void addGlobalMethods(py::module& m) { }, R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc"); + m.def( + "copy_tensors", + [](const std::vector& src, const std::vector& dest, py::object& py_arg) { + const OrtEnv* ort_env = GetOrtEnv(); + OrtSyncStream* stream = nullptr; + if (!py_arg.is_none()) { + stream = py_arg.cast(); + } + Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size())); + }, + R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc"); + #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( "get_available_openvino_device_ids", []() -> std::vector { @@ -1785,6 +1798,16 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .value("CPU", OrtMemTypeCPU) .value("DEFAULT", OrtMemTypeDefault); + py::enum_(m, "OrtMemoryInfoDeviceType") + .value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) + .value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) + .value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU) + .value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA); + + py::enum_(m, "OrtDeviceMemoryType") + .value("DEFAULT", OrtDeviceMemoryType_DEFAULT) + .value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE); + py::class_ device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc"); device.def(py::init()) .def(py::init([](OrtDevice::DeviceType type, @@ -1813,6 +1836,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc") .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc") .def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc") + .def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc") // generic device types that are typically used with a vendor id. .def_static("cpu", []() { return OrtDevice::CPU; }) .def_static("gpu", []() { return OrtDevice::GPU; }) @@ -1863,36 +1887,55 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra }, R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc"); + py::class_ py_sync_stream(m, "OrtSyncStream", + R"pbdoc(Represents a synchronization stream for model inference.)pbdoc"); + py::class_ py_ep_device(m, "OrtEpDevice", R"pbdoc(Represents a hardware device that an execution provider supports for model inference.)pbdoc"); py_ep_device.def_property_readonly( "ep_name", - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, R"pbdoc(The execution provider's name.)pbdoc") .def_property_readonly( "ep_vendor", - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, R"pbdoc(The execution provider's vendor name.)pbdoc") .def_property_readonly( "ep_metadata", - [](OrtEpDevice* ep_device) -> std::map { + [](const OrtEpDevice* ep_device) -> std::map { return ep_device->ep_metadata.Entries(); }, R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc") .def_property_readonly( "ep_options", - [](OrtEpDevice* ep_device) -> std::map { + [](const OrtEpDevice* ep_device) -> std::map { return ep_device->ep_options.Entries(); }, R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc") .def_property_readonly( "device", - [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& { + [](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& { return *ep_device->device; }, R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc", - py::return_value_policy::reference_internal); + py::return_value_policy::reference_internal) + .def( + "memory_info", + [](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* { + Ort::ConstEpDevice ep_dev(ep_device); + return static_cast(ep_dev.GetMemoryInfo(memory_type)); + }, + R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc", + py::return_value_policy::reference_internal) + .def( + "create_sync_stream", + [](const OrtEpDevice* ep_device) -> std::unique_ptr { + Ort::ConstEpDevice ep_dev(ep_device); + Ort::SyncStream stream = ep_dev.CreateSyncStream(); + return std::unique_ptr(stream.release()); + }, + R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc"); py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option. @@ -1938,25 +1981,28 @@ for model inference.)pbdoc"); .def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes); py::class_ ort_memory_info_binding(m, "OrtMemoryInfo"); - ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { - if (strcmp(name, onnxruntime::CPU) == 0) { - return std::make_unique(onnxruntime::CPU, type, OrtDevice(), mem_type); - } else if (strcmp(name, onnxruntime::CUDA) == 0) { - return std::make_unique( - onnxruntime::CUDA, type, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, - static_cast(id)), - mem_type); - } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) { - return std::make_unique( - onnxruntime::CUDA_PINNED, type, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - static_cast(id)), - mem_type); - } else { - throw std::runtime_error("Specified device is not supported."); - } - })); + ort_memory_info_binding.def( + py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { + Ort::MemoryInfo result(name, type, id, mem_type); + return std::unique_ptr(result.release()); + })) + .def_static( + "create_v2", + [](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, + int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) { + Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type); + return std::unique_ptr(result.release()); + }, + R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc") + .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc") + .def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc") + .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc") + .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc") + .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType { + auto mem_type = mem_info->device.MemType(); + return (mem_type == OrtDevice::MemType::DEFAULT) ? + OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc") + .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); }); py::class_ sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc"); @@ -2653,6 +2699,33 @@ including arg name, arg type (contains both type and shape).)pbdoc") auto res = sess->GetSessionHandle()->GetModelMetadata(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) + .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle())); + auto inputs_mem_info = session.GetMemoryInfoForInputs(); + py::list result; + for (const auto& info : inputs_mem_info) { + const auto* p_info = static_cast(info); + result.append(py::cast(p_info, py::return_value_policy::reference)); + } + return result; }) + .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle())); + auto outputs_mem_info = session.GetMemoryInfoForOutputs(); + py::list result; + for (const auto& info : outputs_mem_info) { + const auto* p_info = static_cast(info); + result.append(py::cast(p_info, py::return_value_policy::reference)); + } + return result; }) + .def_property_readonly("input_epdevices", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle())); + auto ep_devices = session.GetEpDeviceForInputs(); + py::list result; + for (const auto& device : ep_devices) { + const auto* p_device = static_cast(device); + result.append(py::cast(p_device, py::return_value_policy::reference)); + } + return result; }) .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void { Status status; diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index e4265713d2d0a..5d8245618dcd6 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -233,7 +233,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); - if (op_type != "Mul") { + if (op_type == "Mul") { // Check that Mul has inputs/output of type float std::vector inputs = node.GetInputs(); std::vector outputs = node.GetOutputs(); @@ -248,11 +248,36 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG continue; // Input or output is not of type float } + { + const auto input_0_shape = GetTensorShape(inputs[0]), + input_1_shape = GetTensorShape(inputs[1]); + + if (!input_0_shape.has_value() || !input_1_shape.has_value()) { + continue; // unable to get input shape + } + + const auto is_static_shape = [](gsl::span shape) -> bool { + return std::all_of(shape.begin(), shape.end(), [](int64_t dim) { return dim >= 0; }); + }; + + if (!is_static_shape(*input_0_shape) || !is_static_shape(*input_1_shape)) { + continue; // input shape has dynamic dimensions + } + + if (*input_0_shape != *input_1_shape) { + continue; // input shapes do not match (no broadcasting support for now) + } + } + supported_nodes.push_back(node); // Only support a single Mul for now. break; } } + if (supported_nodes.empty()) { + return nullptr; + } + // Create (optional) fusion options for the supported nodes to fuse. OrtNodeFusionOptions node_fusion_options = {}; node_fusion_options.ort_version_supported = ORT_API_VERSION; @@ -317,7 +342,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const Ort::ConstNode fused_node{fused_nodes[0]}; auto ep_name = fused_node.GetEpName(); - if (ep_name != "example_ep") { + if (ep_name != ep->name_) { Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL); return status.release(); } diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/ep_stream_support.cc index 1f6c16a8cb358..c648474d4fad7 100644 --- a/onnxruntime/test/autoep/library/ep_stream_support.cc +++ b/onnxruntime/test/autoep/library/ep_stream_support.cc @@ -61,7 +61,12 @@ OrtStatus* ORT_API_CALL NotificationImpl::ActivateImpl(_In_ OrtSyncNotificationI /*static*/ OrtStatus* ORT_API_CALL NotificationImpl::WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, _In_ OrtSyncStream* stream) noexcept { + if (stream == nullptr) { + return nullptr; + } + auto& impl = *static_cast(this_ptr); + void* handle = impl.ort_api.SyncStream_GetHandle(stream); static_cast(handle); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc index 263b4d208bd91..8b36f5f4e9a13 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc @@ -35,3 +35,14 @@ void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { } result = true; } + +std::optional> GetTensorShape(Ort::ConstValueInfo value_info) { + const auto type_info = value_info.TypeInfo(); + const auto onnx_type = type_info.GetONNXType(); + if (onnx_type != ONNX_TYPE_TENSOR) { + return std::nullopt; + } + + const auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + return type_shape.GetShape(); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index e8c086d38a7cb..decc89251dc7b 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -4,7 +4,9 @@ #pragma once #include +#include #include +#include #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" @@ -108,3 +110,6 @@ OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessio // Returns true (via output parameter) if the given OrtValueInfo represents a float tensor. void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result); + +// Gets the tensor shape from `value_info`. Returns std::nullopt if `value_info` is not a tensor. +std::optional> GetTensorShape(Ort::ConstValueInfo value_info); diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 0f4a654f116c4..78be22d082692 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -22,7 +22,8 @@ namespace onnxruntime { namespace test { namespace { -void RunModelWithPluginEp(Ort::SessionOptions& session_options) { + +void RunMulModelWithPluginEp(const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); // Create input @@ -47,6 +48,38 @@ void RunModelWithPluginEp(Ort::SessionOptions& session_options) { gsl::span output_span(output_data, 6); EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); } + +void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_options) { + // This model has Add -> Mul -> Add. The example plugin EP only supports Mul. + Ort::Session session(*ort_env, ORT_TSTR("testdata/add_mul_add.onnx"), session_options); + + // Create inputs + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + + std::vector a_data{1, 2, 3, 4, 5, 6}; + std::vector b_data{2, 3, 4, 5, 6, 7}; + + std::vector ort_inputs{}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), shape.data(), shape.size())); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, b_data.data(), b_data.size(), shape.data(), shape.size())); + + std::array ort_input_names{"A", "B"}; + + // Run session and get outputs + std::array output_names{"C"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(7, 17, 31, 49, 71, 97)); +} + } // namespace // Creates a session with the example plugin EP and runs a model with a single Mul node. @@ -61,7 +94,7 @@ TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - RunModelWithPluginEp(session_options); + RunMulModelWithPluginEp(session_options); } // Creates a session with the example plugin EP and runs a model with a single Mul node. @@ -74,10 +107,23 @@ TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. Ort::SessionOptions session_options; session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); - RunModelWithPluginEp(session_options); + RunMulModelWithPluginEp(session_options); } } +TEST(OrtEpLibrary, PluginEp_AppendV2_PartiallySupportedModelInference) { + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + RunPartiallySupportedModelWithPluginEp(session_options); +} + // Generate an EPContext model with a plugin EP. // This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs. TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { @@ -98,6 +144,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc index 6e86e5b58aead..01b253446974b 100644 --- a/onnxruntime/test/framework/memcpy_transformer_test.cc +++ b/onnxruntime/test/framework/memcpy_transformer_test.cc @@ -71,8 +71,18 @@ void ExpectCopy(const onnxruntime::Node& source, const std::string copy_op, } EXPECT_TRUE(false) << "Copy node expected but not found"; } + #ifdef USE_CUDA +static InlinedVector> GetNotNullProviderPtrs( + const ExecutionProviders& providers) { + InlinedVector> not_null_provider_ptrs{}; + for (auto& provider_ptr : providers) { + not_null_provider_ptrs.emplace_back(provider_ptr.get()); + } + return not_null_provider_ptrs; +} + TEST(TransformerTest, MemcpyTransformerTest) { std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = 7; @@ -112,7 +122,11 @@ TEST(TransformerTest, MemcpyTransformerTest) { KernelRegistryManager test_registry_manager; ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers)); - MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); + InlinedVector> providers; + for (auto& provider_ptr : execution_providers) { + providers.push_back(provider_ptr.get()); + } + MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -167,7 +181,7 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { KernelRegistryManager test_registry_manager; ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers)); - MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); + MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -262,6 +276,8 @@ TEST(TransformerTest, TestInitializerDuplicationInSubgraph) { if_node.AddAttribute("then_branch", subgraph.ToGraphProto()); if_node.AddAttribute("else_branch", subgraph.ToGraphProto()); + if_node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + onnxruntime::Graph* subgraph_1 = if_node.GetMutableGraphAttribute("then_branch"); for (auto& node : subgraph_1->Nodes()) { if (node.Name() == "node2") { @@ -287,7 +303,7 @@ TEST(TransformerTest, TestInitializerDuplicationInSubgraph) { KernelRegistryManager test_registry_manager; ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers)); - MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); + MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager); bool modified = false; ASSERT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); @@ -329,7 +345,7 @@ TEST(TransformerTest, MemcpyTransformerTestGraphInputConsumedOnMultipleDevices) KernelRegistryManager test_registry_manager; ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers)); - MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); + MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -398,6 +414,8 @@ TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevice if_node.AddAttribute("then_branch", subgraph.ToGraphProto()); if_node.AddAttribute("else_branch", subgraph.ToGraphProto()); + if_node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + graph.SetInputs({&i1_def, &i2_def}); onnxruntime::Graph* subgraph_1 = if_node.GetMutableGraphAttribute("then_branch"); @@ -431,7 +449,7 @@ TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevice KernelRegistryManager test_registry_manager; ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers)); - MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); + MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index bef0bdd5295be..d56212510d2a9 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -795,6 +795,24 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); // Please make no more changes to the list static const ORTCHAR_T* immutable_broken_tests[] = { + // pending ONNX update + ORT_TSTR("attention_3d_gqa"), + ORT_TSTR("attention_3d_gqa_attn_mask"), + ORT_TSTR("attention_3d_gqa_causal"), + ORT_TSTR("attention_3d_gqa_scaled"), + ORT_TSTR("attention_3d_gqa_softcap"), + ORT_TSTR("attention_3d_gqa_with_past_and_present"), + ORT_TSTR("attention_4d_gqa"), + ORT_TSTR("attention_4d_gqa_attn_mask"), + ORT_TSTR("attention_4d_gqa_causal"), + ORT_TSTR("attention_4d_gqa_scaled"), + ORT_TSTR("attention_4d_gqa_softcap"), + ORT_TSTR("attention_4d_gqa_with_past_and_present"), + ORT_TSTR("attention_4d_diff_heads_mask4d_padded_kv"), + ORT_TSTR("attention_4d_gqa_with_past_and_present_fp16"), + ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal"), + ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal"), + // unsupported case ORT_TSTR("AvgPool1d"), ORT_TSTR("AvgPool1d_stride"), ORT_TSTR("AvgPool2d"), diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index f6fce37322c10..c382612a6dff8 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -4518,7 +4518,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) { // changes during the layout transformation process. ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); - using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider; // set the test EP to support all ops in the model so that the layout transform applies to all nodes const std::unordered_set empty_set; @@ -4544,7 +4544,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) { "with the exception of the initial node prior to the Conv"; // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout - std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider); for (const auto& node : graph.Nodes()) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; @@ -4564,7 +4564,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { SessionOptions so; - using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider; // set the test EP to support all ops in the model so that the layout transform applies to all nodes const std::unordered_set empty_set; @@ -4590,7 +4590,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { "with the exception of the initial node prior to the Conv"; // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout - std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider); for (const auto& node : graph.Nodes()) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; @@ -4606,7 +4606,7 @@ TEST(TransposeOptimizerTests, QnnResizeOpset11) { // Uncomment to debug // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); - using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider; // set the test EP to support all ops in the model so that the layout transform applies to all nodes const std::unordered_set empty_set; @@ -4620,7 +4620,7 @@ TEST(TransposeOptimizerTests, QnnResizeOpset11) { const auto& graph = session.GetGraph(); // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout - std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider); for (const auto& node : graph.Nodes()) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; @@ -4648,7 +4648,7 @@ TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) { // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); - using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider; // set the test EP to support all ops in the model so that the layout transform applies to all nodes const std::unordered_set empty_set; @@ -4666,7 +4666,7 @@ TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) { ASSERT_EQ(op_to_count["Transpose"], 3) << "Should have Transpose on 2 inputs and one on output."; // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout - std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider); for (const auto& node : graph.Nodes()) { EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; @@ -4699,7 +4699,7 @@ TEST(TransposeOptimizerTests, LayoutTransformFixStuckTransposeWithoutDQ) { // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); - using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider; // Set the test EP to support all ops in the model so that the layout transform applies to all nodes const std::unordered_set empty_set; @@ -4716,7 +4716,7 @@ TEST(TransposeOptimizerTests, LayoutTransformFixStuckTransposeWithoutDQ) { ASSERT_EQ(op_to_count["Transpose"], 2) << "Should have 2 transposes remaining."; - std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider); for (const auto& node : graph.Nodes()) { EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; @@ -4756,7 +4756,7 @@ TEST(TransposeOptimizerTests, LayoutTransformConstantFoldTransposeAndSqueeze) { // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); - using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider; // Set the test EP to support all ops in the model so that the layout transform applies to all nodes const std::unordered_set empty_set; @@ -4777,7 +4777,7 @@ TEST(TransposeOptimizerTests, LayoutTransformConstantFoldTransposeAndSqueeze) { // 1 transpose is constant-folded, 1 is canceled, and 1 remains. ASSERT_EQ(op_to_count["Transpose"], 1) << "Should have 1 transpose remaining."; - std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider); for (const auto& node : graph.Nodes()) { EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index b4f6d328cacf7..54c2ed7d521db 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -664,6 +664,7 @@ TEST(AttentionTest, Attention4DAttnPastPresent) { false, true, true // disable_cpu, disable_cuda, disable_dml ); } + TEST(AttentionTest, Attention4DAttnIsCausal) { int batch_size = 2; // Q.shape[0] int q_num_heads = 3; // Q.shape[1] @@ -828,6 +829,38 @@ TEST(AttentionTest, Attention4DDiffHeadsWithPastAndPresent) { ); } +TEST(AttentionTest, Attention3DGqaAttn) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 9; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + // {2, 4, 72} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f}; + // {2, 6, 24} + std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 6, 24} + std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + // {2, 4, 72} + std::vector y = {0.532009f, 0.526025f, 0.449746f, 0.551692f, 0.407822f, 0.436275f, 0.507807f, 0.457324f, 0.530536f, 0.517111f, 0.452785f, 0.557318f, 0.397721f, 0.434161f, 0.498276f, 0.464536f, 0.528016f, 0.548671f, 0.441040f, 0.542961f, 0.418557f, 0.444397f, 0.515088f, 0.452512f, 0.462161f, 0.530536f, 0.564630f, 0.418701f, 0.669452f, 0.633554f, 0.569379f, 0.430544f, 0.456026f, 0.529795f, 0.558238f, 0.411985f, 0.664240f, 0.619959f, 0.590516f, 0.438577f, 0.471552f, 0.521718f, 0.560465f, 0.404206f, 0.663920f, 0.628819f, 0.540935f, 0.447763f, 0.615083f, 0.344791f, 0.432664f, 0.451253f, 0.460813f, 0.441267f, 0.708582f, 0.530088f, 0.623659f, 0.343547f, 0.439418f, 0.450767f, 0.460055f, 0.442001f, 0.703292f, 0.522883f, 0.617738f, 0.343160f, 0.440540f, 0.440079f, 0.459815f, 0.436860f, 0.703290f, 0.534856f, 0.536138f, 0.499439f, 0.465771f, 0.565138f, 0.391402f, 0.430258f, 0.494915f, 0.463613f, 0.532752f, 0.526358f, 0.452075f, 0.562130f, 0.402551f, 0.442784f, 0.486721f, 0.456955f, 0.547578f, 0.527342f, 0.453800f, 0.548887f, 0.418444f, 0.438968f, 0.515475f, 0.444207f, 0.475352f, 0.524010f, 0.549702f, 0.420030f, 0.656346f, 0.620729f, 0.571884f, 0.431010f, 0.453307f, 0.522210f, 0.563368f, 0.412061f, 0.657897f, 0.634999f, 0.577458f, 0.451691f, 0.473936f, 0.524285f, 0.553525f, 0.421768f, 0.662288f, 0.622833f, 0.570081f, 0.432808f, 0.625738f, 0.353159f, 0.436185f, 0.448597f, 0.459371f, 0.429822f, 0.709026f, 0.526207f, 0.630878f, 0.351036f, 0.439799f, 0.452249f, 0.456486f, 0.431906f, 0.706014f, 0.518897f, 0.629526f, 0.351482f, 0.440728f, 0.449287f, 0.451705f, 0.426815f, 0.706598f, 0.522028f, 0.537899f, 0.527199f, 0.447980f, 0.548688f, 0.410653f, 0.436181f, 0.511135f, 0.455244f, 0.534560f, 0.540045f, 0.447505f, 0.552786f, 0.413302f, 0.446360f, 0.499945f, 0.450757f, 0.531708f, 0.526097f, 0.450511f, 0.553372f, 0.401450f, 0.438186f, 0.501418f, 0.462466f, 0.469643f, 0.527539f, 0.553613f, 0.418159f, 0.659814f, 0.622731f, 0.575224f, 0.429425f, 0.463941f, 0.524481f, 0.557632f, 0.413729f, 0.657415f, 0.629157f, 0.570920f, 0.439773f, 0.479643f, 0.526773f, 0.556809f, 0.422406f, 0.670038f, 0.625300f, 0.554451f, 0.426587f, 0.630894f, 0.353011f, 0.444285f, 0.443177f, 0.448608f, 0.419312f, 0.705883f, 0.526260f, 0.631310f, 0.347563f, 0.445672f, 0.446224f, 0.448210f, 0.428481f, 0.702004f, 0.519990f, 0.626158f, 0.342802f, 0.449770f, 0.440666f, 0.453705f, 0.427492f, 0.700510f, 0.533279f, 0.526144f, 0.538202f, 0.443619f, 0.551579f, 0.407162f, 0.442426f, 0.499995f, 0.459987f, 0.525627f, 0.544718f, 0.448060f, 0.544942f, 0.415781f, 0.444198f, 0.516948f, 0.452985f, 0.521784f, 0.523083f, 0.450924f, 0.565538f, 0.392054f, 0.440702f, 0.479094f, 0.468113f, 0.473886f, 0.523677f, 0.555144f, 0.409412f, 0.664285f, 0.620163f, 0.555448f, 0.440947f, 0.459210f, 0.528829f, 0.567231f, 0.413602f, 0.672778f, 0.632467f, 0.565881f, 0.439895f, 0.480238f, 0.525127f, 0.554365f, 0.431656f, 0.658900f, 0.634358f, 0.561181f, 0.419623f, 0.646099f, 0.364754f, 0.442180f, 0.450340f, 0.441320f, 0.412523f, 0.708121f, 0.505939f, 0.641772f, 0.375478f, 0.428502f, 0.454772f, 0.439016f, 0.407773f, 0.718457f, 0.504047f, 0.628271f, 0.345239f, 0.449391f, 0.436208f, 0.448766f, 0.426444f, 0.699202f, 0.528374f, 0.489165f, 0.818278f, 0.467403f, 0.370507f, 0.572406f, 0.417942f, 0.160316f, 0.384139f, 0.497723f, 0.820329f, 0.455669f, 0.373132f, 0.568626f, 0.418602f, 0.164551f, 0.404233f, 0.488972f, 0.813399f, 0.460936f, 0.369774f, 0.580477f, 0.417018f, 0.167442f, 0.381535f, 0.603715f, 0.360599f, 0.371685f, 0.614777f, 0.440767f, 0.425124f, 0.369342f, 0.828101f, 0.584460f, 0.352249f, 0.382191f, 0.613073f, 0.431223f, 0.421802f, 0.389292f, 0.831202f, 0.590574f, 0.355658f, 0.373391f, 0.623741f, 0.432416f, 0.412097f, 0.378312f, 0.829226f, 0.365226f, 0.726961f, 0.549872f, 0.239494f, 0.496434f, 0.668542f, 0.557774f, 0.487281f, 0.361340f, 0.749156f, 0.523408f, 0.240555f, 0.493770f, 0.639516f, 0.552116f, 0.478230f, 0.367118f, 0.740114f, 0.563789f, 0.238852f, 0.498407f, 0.682064f, 0.571327f, 0.496416f, 0.480636f, 0.820258f, 0.464776f, 0.362168f, 0.567256f, 0.417842f, 0.161815f, 0.387104f, 0.486998f, 0.821507f, 0.467362f, 0.377934f, 0.569593f, 0.418367f, 0.156778f, 0.390179f, 0.461449f, 0.823726f, 0.471401f, 0.361646f, 0.563554f, 0.418609f, 0.154999f, 0.379696f, 0.565916f, 0.345293f, 0.392969f, 0.612305f, 0.418858f, 0.416238f, 0.410985f, 0.833515f, 0.552881f, 0.338985f, 0.394863f, 0.597100f, 0.422296f, 0.401025f, 0.427810f, 0.831702f, 0.558983f, 0.339943f, 0.393544f, 0.583418f, 0.432193f, 0.405729f, 0.426401f, 0.830305f, 0.362801f, 0.731181f, 0.546338f, 0.247016f, 0.499389f, 0.662441f, 0.544727f, 0.486631f, 0.355514f, 0.726998f, 0.518056f, 0.249475f, 0.492155f, 0.643678f, 0.531052f, 0.481617f, 0.370308f, 0.743741f, 0.562172f, 0.233361f, 0.498431f, 0.679567f, 0.580747f, 0.494199f, 0.481097f, 0.817782f, 0.461707f, 0.369188f, 0.573825f, 0.419752f, 0.161614f, 0.386708f, 0.472911f, 0.822003f, 0.473412f, 0.375830f, 0.569966f, 0.422158f, 0.149228f, 0.380008f, 0.454662f, 0.818956f, 0.465984f, 0.370169f, 0.575537f, 0.423344f, 0.153818f, 0.375466f, 0.572526f, 0.348075f, 0.380718f, 0.641409f, 0.417012f, 0.407621f, 0.389074f, 0.834251f, 0.581008f, 0.348183f, 0.383659f, 0.608061f, 0.435032f, 0.422240f, 0.393710f, 0.832528f, 0.600530f, 0.360439f, 0.371006f, 0.609018f, 0.441082f, 0.416286f, 0.374920f, 0.825853f, 0.364932f, 0.727047f, 0.540001f, 0.246375f, 0.501524f, 0.656266f, 0.541761f, 0.482865f, 0.360322f, 0.752650f, 0.542120f, 0.239561f, 0.491207f, 0.663446f, 0.566643f, 0.491988f, 0.364532f, 0.737402f, 0.546869f, 0.240953f, 0.497072f, 0.664793f, 0.558528f, 0.488182f, 0.490592f, 0.819727f, 0.468739f, 0.379671f, 0.572959f, 0.422399f, 0.152699f, 0.387445f, 0.462308f, 0.822644f, 0.463886f, 0.374320f, 0.569615f, 0.423238f, 0.152603f, 0.387850f, 0.451896f, 0.818576f, 0.449904f, 0.362889f, 0.573917f, 0.421849f, 0.165145f, 0.390440f, 0.565044f, 0.343397f, 0.395512f, 0.584043f, 0.431062f, 0.417783f, 0.421165f, 0.830938f, 0.583998f, 0.354061f, 0.374016f, 0.633981f, 0.424457f, 0.404069f, 0.381920f, 0.829920f, 0.568315f, 0.347357f, 0.386911f, 0.624227f, 0.418162f, 0.411256f, 0.400332f, 0.832994f, 0.370475f, 0.739716f, 0.551429f, 0.234114f, 0.499500f, 0.665245f, 0.570648f, 0.485298f, 0.364035f, 0.756092f, 0.542251f, 0.238706f, 0.495463f, 0.659518f, 0.567976f, 0.489204f, 0.368942f, 0.756397f, 0.548083f, 0.231854f, 0.496617f, 0.659726f, 0.578330f, 0.484921f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + TEST(AttentionTest, Attention4DGqaAttnMask) { int batch_size = 2; // Q.shape[0] int q_num_heads = 9; // Q.shape[1] @@ -847,7 +880,7 @@ TEST(AttentionTest, Attention4DGqaAttnMask) { // {4, 6} std::vector m = {0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f}; // {2, 9, 4, 8} - std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.508992f, 0.253478f, 0.553979f, 0.466355f, 0.398637f, 0.412493f, 0.495810f, 0.677675f, 0.521609f, 0.278997f, 0.564189f, 0.434417f, 0.448085f, 0.467205f, 0.567856f, 0.664713f, 0.490146f, 0.261321f, 0.560582f, 0.424598f, 0.450318f, 0.467336f, 0.520983f, 0.720798f, 0.516095f, 0.264495f, 0.577940f, 0.475340f, 0.444145f, 0.477909f, 0.485663f, 0.672846f, 0.499389f, 0.402198f, 0.520218f, 0.550550f, 0.481065f, 0.730488f, 0.492535f, 0.392315f, 0.436722f, 0.398514f, 0.497457f, 0.502270f, 0.520993f, 0.730472f, 0.565429f, 0.380282f, 0.461226f, 0.392968f, 0.536035f, 0.505191f, 0.446570f, 0.751253f, 0.478584f, 0.389036f, 0.423738f, 0.443828f, 0.554323f, 0.462607f, 0.476656f, 0.733228f, 0.482219f, 0.411910f, 0.620556f, 0.662948f, 0.349409f, 0.482541f, 0.537250f, 0.351544f, 0.734285f, 0.397172f, 0.689500f, 0.637077f, 0.320710f, 0.470914f, 0.526307f, 0.312878f, 0.775762f, 0.384457f, 0.696615f, 0.681034f, 0.324383f, 0.459632f, 0.539497f, 0.317950f, 0.709736f, 0.320698f, 0.671696f, 0.676830f, 0.332387f, 0.453234f, 0.578648f, 0.345084f, 0.685369f, 0.328092f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.500042f, 0.410507f, 0.521381f, 0.553244f, 0.459062f, 0.719706f, 0.476571f, 0.395052f, 0.429926f, 0.408857f, 0.507006f, 0.493937f, 0.529878f, 0.728873f, 0.571495f, 0.376256f, 0.453676f, 0.380482f, 0.526100f, 0.496696f, 0.457383f, 0.761933f, 0.486657f, 0.396608f, 0.435748f, 0.432822f, 0.531763f, 0.482255f, 0.477046f, 0.726381f, 0.487480f, 0.416572f, 0.626676f, 0.683736f, 0.340657f, 0.475002f, 0.549981f, 0.353311f, 0.740157f, 0.378827f, 0.681403f, 0.636622f, 0.324593f, 0.469088f, 0.537323f, 0.321344f, 0.762506f, 0.384239f, 0.693108f, 0.683351f, 0.329873f, 0.460504f, 0.555115f, 0.325379f, 0.694659f, 0.316422f, 0.677285f, 0.670298f, 0.329724f, 0.456327f, 0.567533f, 0.337560f, 0.701396f, 0.336191f, 0.515940f, 0.251020f, 0.562035f, 0.442479f, 0.405802f, 0.410828f, 0.519841f, 0.686781f, 0.522057f, 0.285013f, 0.562761f, 0.453472f, 0.451971f, 0.481286f, 0.558322f, 0.649971f, 0.486787f, 0.258011f, 0.557963f, 0.426743f, 0.442028f, 0.457034f, 0.510534f, 0.724945f, 0.498901f, 0.272090f, 0.572650f, 0.467930f, 0.465335f, 0.506181f, 0.484559f, 0.690090f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.379088f, 0.582413f, 0.414383f, 0.571800f, 0.613176f, 0.687631f, 0.185596f, 0.656867f, 0.390452f, 0.532452f, 0.407547f, 0.564799f, 0.606499f, 0.653258f, 0.176547f, 0.698038f, 0.410398f, 0.604586f, 0.442972f, 0.497533f, 0.595085f, 0.732265f, 0.187201f, 0.663169f, 0.448716f, 0.590302f, 0.411879f, 0.518449f, 0.636722f, 0.695827f, 0.154292f, 0.666828f, 0.458054f, 0.608582f, 0.430376f, 0.316371f, 0.547620f, 0.542559f, 0.542043f, 0.556297f, 0.468371f, 0.559154f, 0.465195f, 0.344099f, 0.482571f, 0.527115f, 0.527529f, 0.616254f, 0.494566f, 0.605555f, 0.432360f, 0.382197f, 0.466678f, 0.556031f, 0.459313f, 0.588575f, 0.532798f, 0.597684f, 0.412305f, 0.393400f, 0.462773f, 0.491821f, 0.483189f, 0.593919f, 0.569241f, 0.793791f, 0.532988f, 0.300026f, 0.393843f, 0.327085f, 0.448199f, 0.457416f, 0.493302f, 0.725336f, 0.512066f, 0.327500f, 0.404238f, 0.351704f, 0.507818f, 0.477990f, 0.479548f, 0.756083f, 0.511730f, 0.309729f, 0.366024f, 0.338031f, 0.503335f, 0.472352f, 0.473026f, 0.696816f, 0.543129f, 0.374608f, 0.335432f, 0.360978f, 0.486364f, 0.531799f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.453034f, 0.596627f, 0.417365f, 0.314318f, 0.554269f, 0.518967f, 0.550250f, 0.556252f, 0.494918f, 0.587774f, 0.467566f, 0.350222f, 0.481994f, 0.538857f, 0.525631f, 0.605359f, 0.497486f, 0.608472f, 0.429145f, 0.384532f, 0.466790f, 0.554752f, 0.457698f, 0.586510f, 0.548577f, 0.604359f, 0.398097f, 0.414429f, 0.448200f, 0.485158f, 0.461395f, 0.593015f, 0.563470f, 0.796184f, 0.532783f, 0.293209f, 0.408910f, 0.327450f, 0.438028f, 0.447011f, 0.493041f, 0.739603f, 0.496957f, 0.311881f, 0.389768f, 0.352503f, 0.530113f, 0.476738f, 0.484897f, 0.752985f, 0.511921f, 0.312174f, 0.370408f, 0.339775f, 0.504061f, 0.473793f, 0.487978f, 0.714687f, 0.538817f, 0.358426f, 0.348908f, 0.355820f, 0.481380f, 0.516214f, 0.370872f, 0.602034f, 0.400225f, 0.611090f, 0.630508f, 0.662527f, 0.162489f, 0.658299f, 0.378734f, 0.537283f, 0.412214f, 0.570032f, 0.601452f, 0.653569f, 0.179932f, 0.693105f, 0.411981f, 0.605715f, 0.448022f, 0.481469f, 0.585099f, 0.748463f, 0.195177f, 0.671915f, 0.442141f, 0.581881f, 0.393362f, 0.555388f, 0.650764f, 0.665937f, 0.141141f, 0.675100f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f}; + std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.644472f, 0.666279f, 0.336558f, 0.478260f, 0.534820f, 0.338286f, 0.756443f, 0.387184f, 0.674255f, 0.645509f, 0.327427f, 0.465534f, 0.543598f, 0.328256f, 0.743604f, 0.373978f, 0.689753f, 0.687485f, 0.332246f, 0.457085f, 0.565540f, 0.331625f, 0.677863f, 0.308191f, 0.663033f, 0.669169f, 0.333832f, 0.452516f, 0.576569f, 0.348823f, 0.685447f, 0.338196f, 0.613061f, 0.681689f, 0.345384f, 0.474784f, 0.541609f, 0.357958f, 0.728217f, 0.383408f, 0.680108f, 0.637886f, 0.329455f, 0.469504f, 0.544973f, 0.325193f, 0.745572f, 0.378169f, 0.695405f, 0.687321f, 0.323229f, 0.456101f, 0.553544f, 0.323743f, 0.706057f, 0.314785f, 0.672814f, 0.678842f, 0.323628f, 0.449345f, 0.572724f, 0.342071f, 0.707722f, 0.332714f, 0.512254f, 0.252087f, 0.555774f, 0.456582f, 0.393340f, 0.400567f, 0.501655f, 0.680466f, 0.530775f, 0.288611f, 0.570275f, 0.444357f, 0.454871f, 0.480588f, 0.567893f, 0.645871f, 0.491847f, 0.262209f, 0.561930f, 0.418081f, 0.444398f, 0.456345f, 0.519658f, 0.722565f, 0.523232f, 0.267034f, 0.591659f, 0.459565f, 0.462164f, 0.494775f, 0.497558f, 0.678628f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.513034f, 0.252153f, 0.561841f, 0.455825f, 0.411518f, 0.424734f, 0.508095f, 0.683202f, 0.537475f, 0.278680f, 0.572605f, 0.449901f, 0.433722f, 0.452424f, 0.554372f, 0.643199f, 0.503808f, 0.259719f, 0.571011f, 0.415224f, 0.442363f, 0.450636f, 0.525191f, 0.716156f, 0.524579f, 0.263175f, 0.588806f, 0.462952f, 0.450874f, 0.480435f, 0.495070f, 0.675950f, 0.503113f, 0.409947f, 0.538941f, 0.550010f, 0.457564f, 0.729741f, 0.472483f, 0.384586f, 0.421666f, 0.416784f, 0.522405f, 0.484472f, 0.519795f, 0.728113f, 0.570887f, 0.363251f, 0.462182f, 0.372738f, 0.510951f, 0.511798f, 0.446353f, 0.754695f, 0.485592f, 0.397135f, 0.421437f, 0.447040f, 0.546262f, 0.462919f, 0.473860f, 0.726421f, 0.479062f, 0.420641f, 0.498228f, 0.402912f, 0.524895f, 0.548811f, 0.462668f, 0.729601f, 0.480759f, 0.390396f, 0.421638f, 0.418506f, 0.518644f, 0.484993f, 0.512452f, 0.724489f, 0.562537f, 0.370564f, 0.461864f, 0.376424f, 0.511195f, 0.510163f, 0.461531f, 0.755198f, 0.491549f, 0.400847f, 0.425338f, 0.456035f, 0.553542f, 0.466468f, 0.482400f, 0.722062f, 0.483532f, 0.415135f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.560779f, 0.795626f, 0.527843f, 0.292198f, 0.403399f, 0.328103f, 0.449548f, 0.449270f, 0.492632f, 0.741337f, 0.501964f, 0.308729f, 0.404425f, 0.353946f, 0.510715f, 0.469292f, 0.498506f, 0.749246f, 0.510938f, 0.317603f, 0.377607f, 0.333171f, 0.516589f, 0.472113f, 0.494030f, 0.738331f, 0.525273f, 0.334388f, 0.351797f, 0.349013f, 0.492978f, 0.499192f, 0.558701f, 0.785575f, 0.541472f, 0.309741f, 0.379566f, 0.336180f, 0.433460f, 0.471779f, 0.500494f, 0.748997f, 0.495158f, 0.302537f, 0.401868f, 0.348977f, 0.525071f, 0.465493f, 0.496427f, 0.763380f, 0.504640f, 0.303037f, 0.375539f, 0.332025f, 0.517142f, 0.464096f, 0.466789f, 0.731320f, 0.529262f, 0.338950f, 0.329005f, 0.361720f, 0.481664f, 0.514476f, 0.356477f, 0.623874f, 0.420893f, 0.592125f, 0.610336f, 0.687956f, 0.174269f, 0.652548f, 0.366057f, 0.567382f, 0.428770f, 0.553226f, 0.582617f, 0.683498f, 0.188604f, 0.695704f, 0.406930f, 0.625170f, 0.441775f, 0.499327f, 0.590722f, 0.740689f, 0.180721f, 0.681143f, 0.430954f, 0.584531f, 0.412720f, 0.532459f, 0.630830f, 0.690216f, 0.161882f, 0.663851f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.365268f, 0.611770f, 0.413907f, 0.600775f, 0.622849f, 0.667798f, 0.164152f, 0.647839f, 0.377540f, 0.543255f, 0.401769f, 0.588162f, 0.610896f, 0.645976f, 0.172500f, 0.695675f, 0.428349f, 0.590245f, 0.429343f, 0.497694f, 0.606978f, 0.727059f, 0.182826f, 0.671502f, 0.466759f, 0.580932f, 0.396764f, 0.527984f, 0.655065f, 0.677027f, 0.138356f, 0.672848f, 0.431113f, 0.593599f, 0.391529f, 0.327778f, 0.551802f, 0.526872f, 0.512055f, 0.547473f, 0.461591f, 0.564565f, 0.469932f, 0.335454f, 0.493299f, 0.536959f, 0.537769f, 0.611109f, 0.505296f, 0.606927f, 0.414343f, 0.395585f, 0.462205f, 0.538029f, 0.450814f, 0.585742f, 0.550355f, 0.606479f, 0.419783f, 0.396625f, 0.449703f, 0.500831f, 0.464506f, 0.594653f, 0.460993f, 0.609826f, 0.424563f, 0.322395f, 0.546231f, 0.537700f, 0.541169f, 0.555672f, 0.479953f, 0.573210f, 0.449011f, 0.356276f, 0.482535f, 0.523785f, 0.516393f, 0.605958f, 0.473948f, 0.587667f, 0.412118f, 0.378344f, 0.472903f, 0.540161f, 0.445341f, 0.585184f, 0.561693f, 0.609513f, 0.394200f, 0.418769f, 0.444939f, 0.478136f, 0.458334f, 0.591187f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f}; ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); @@ -886,7 +919,7 @@ TEST(AttentionTest, Attention4DGqaWithPastAndPresent) { // {2, 3, 12, 8} std::vector past_value = {0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.481102f, 0.251523f, 0.876682f, 0.324273f, 0.924623f, 0.974787f, 0.449862f, 0.227129f, 0.291666f, 0.776334f, 0.273350f, 0.380583f, 0.478576f, 0.575111f, 0.996100f, 0.232210f, 0.353424f, 0.262891f, 0.361113f, 0.100805f, 0.359810f, 0.887865f, 0.298590f, 0.371935f}; // {2, 9, 4, 8} - std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.440810f, 0.437705f, 0.476508f, 0.320820f, 0.605191f, 0.640150f, 0.306216f, 0.610947f, 0.485794f, 0.448216f, 0.485639f, 0.323744f, 0.594446f, 0.646597f, 0.321742f, 0.605751f, 0.501858f, 0.445502f, 0.487899f, 0.384660f, 0.597134f, 0.616430f, 0.331401f, 0.566459f, 0.502522f, 0.409965f, 0.526639f, 0.348601f, 0.565200f, 0.586558f, 0.325044f, 0.603422f, 0.450250f, 0.368009f, 0.550911f, 0.460338f, 0.523907f, 0.508816f, 0.575624f, 0.426601f, 0.472310f, 0.372844f, 0.517852f, 0.431688f, 0.551555f, 0.527657f, 0.600578f, 0.473069f, 0.456633f, 0.442035f, 0.539875f, 0.437863f, 0.540202f, 0.499608f, 0.556470f, 0.419831f, 0.463081f, 0.416724f, 0.526389f, 0.458654f, 0.540120f, 0.551554f, 0.569399f, 0.447102f, 0.534296f, 0.597655f, 0.509699f, 0.487167f, 0.607438f, 0.426383f, 0.522794f, 0.458435f, 0.510147f, 0.622761f, 0.501724f, 0.453386f, 0.629671f, 0.434103f, 0.582477f, 0.437681f, 0.520031f, 0.568543f, 0.525216f, 0.490370f, 0.571745f, 0.428629f, 0.572995f, 0.460086f, 0.533607f, 0.614962f, 0.474130f, 0.456345f, 0.576467f, 0.448127f, 0.599211f, 0.432252f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.452598f, 0.361594f, 0.550919f, 0.455099f, 0.530404f, 0.519313f, 0.588655f, 0.431890f, 0.464325f, 0.389636f, 0.515359f, 0.429087f, 0.540767f, 0.518376f, 0.586627f, 0.471074f, 0.458527f, 0.422216f, 0.537762f, 0.434123f, 0.550956f, 0.507704f, 0.564828f, 0.421548f, 0.463044f, 0.407985f, 0.523093f, 0.473684f, 0.542663f, 0.551348f, 0.576783f, 0.448743f, 0.546208f, 0.621128f, 0.501647f, 0.468191f, 0.612298f, 0.425183f, 0.549241f, 0.447622f, 0.519355f, 0.619636f, 0.487775f, 0.444259f, 0.625749f, 0.430264f, 0.584338f, 0.436887f, 0.521021f, 0.572716f, 0.522539f, 0.486440f, 0.581317f, 0.429079f, 0.579691f, 0.455426f, 0.526431f, 0.604615f, 0.476481f, 0.469814f, 0.588766f, 0.445640f, 0.609160f, 0.437785f, 0.443498f, 0.439338f, 0.487424f, 0.310942f, 0.607341f, 0.630362f, 0.312591f, 0.621999f, 0.483917f, 0.446308f, 0.477454f, 0.331028f, 0.592608f, 0.653297f, 0.322368f, 0.599377f, 0.497354f, 0.443447f, 0.477781f, 0.384002f, 0.591587f, 0.610287f, 0.328537f, 0.567630f, 0.499369f, 0.421961f, 0.536492f, 0.345379f, 0.586450f, 0.600541f, 0.312965f, 0.609437f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.557802f, 0.585925f, 0.426858f, 0.464044f, 0.585251f, 0.557395f, 0.433327f, 0.615342f, 0.534368f, 0.573723f, 0.426393f, 0.518102f, 0.586735f, 0.513129f, 0.371969f, 0.636735f, 0.544166f, 0.588469f, 0.433470f, 0.481894f, 0.595019f, 0.533156f, 0.396519f, 0.608115f, 0.547125f, 0.604473f, 0.441984f, 0.469765f, 0.599107f, 0.561685f, 0.347618f, 0.563457f, 0.507550f, 0.485293f, 0.545846f, 0.408434f, 0.482538f, 0.532314f, 0.498883f, 0.525126f, 0.514603f, 0.471457f, 0.539705f, 0.362410f, 0.490158f, 0.513690f, 0.494170f, 0.496909f, 0.492936f, 0.506153f, 0.565865f, 0.364727f, 0.508899f, 0.516217f, 0.558362f, 0.556920f, 0.530472f, 0.521715f, 0.554673f, 0.363830f, 0.509086f, 0.511590f, 0.552396f, 0.541486f, 0.572145f, 0.551531f, 0.471964f, 0.485188f, 0.555030f, 0.493247f, 0.376875f, 0.429387f, 0.580540f, 0.550944f, 0.435664f, 0.480675f, 0.544997f, 0.488698f, 0.344985f, 0.464878f, 0.593774f, 0.541202f, 0.484834f, 0.497316f, 0.509364f, 0.500045f, 0.357235f, 0.448933f, 0.565242f, 0.546653f, 0.459790f, 0.481954f, 0.514950f, 0.516297f, 0.344285f, 0.454476f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.507041f, 0.473640f, 0.547768f, 0.413960f, 0.490513f, 0.534377f, 0.497277f, 0.517772f, 0.531394f, 0.489105f, 0.531671f, 0.369343f, 0.486462f, 0.501787f, 0.494220f, 0.493498f, 0.485968f, 0.510301f, 0.559766f, 0.361474f, 0.507888f, 0.518858f, 0.564300f, 0.561990f, 0.537984f, 0.527982f, 0.539571f, 0.366920f, 0.498313f, 0.505709f, 0.538027f, 0.541246f, 0.585733f, 0.565800f, 0.441346f, 0.476255f, 0.556453f, 0.497693f, 0.363246f, 0.426799f, 0.578484f, 0.556489f, 0.436699f, 0.481177f, 0.549473f, 0.484153f, 0.355910f, 0.462010f, 0.590951f, 0.542803f, 0.470954f, 0.488994f, 0.512707f, 0.511876f, 0.358555f, 0.455953f, 0.559449f, 0.546003f, 0.462900f, 0.471080f, 0.517298f, 0.519225f, 0.345016f, 0.449149f, 0.526624f, 0.606761f, 0.427660f, 0.480775f, 0.577420f, 0.538850f, 0.426959f, 0.625509f, 0.530502f, 0.585784f, 0.432234f, 0.516800f, 0.584937f, 0.514154f, 0.373726f, 0.623740f, 0.550470f, 0.585577f, 0.436483f, 0.474799f, 0.594100f, 0.540052f, 0.402520f, 0.607686f, 0.537556f, 0.609680f, 0.439490f, 0.477886f, 0.602656f, 0.542957f, 0.350394f, 0.574553f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f}; + std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.546090f, 0.618047f, 0.504325f, 0.472246f, 0.609686f, 0.422467f, 0.546964f, 0.451166f, 0.519404f, 0.617868f, 0.491984f, 0.445771f, 0.633094f, 0.436822f, 0.559753f, 0.447209f, 0.519860f, 0.574899f, 0.525759f, 0.489339f, 0.586803f, 0.436452f, 0.577737f, 0.453299f, 0.532473f, 0.609446f, 0.471758f, 0.455772f, 0.573504f, 0.445466f, 0.602573f, 0.433307f, 0.538062f, 0.604199f, 0.500302f, 0.479569f, 0.614174f, 0.429231f, 0.522434f, 0.459369f, 0.528422f, 0.620683f, 0.485333f, 0.435606f, 0.616579f, 0.432233f, 0.565856f, 0.440093f, 0.525356f, 0.580613f, 0.529584f, 0.483095f, 0.583395f, 0.433491f, 0.593043f, 0.451879f, 0.540119f, 0.622995f, 0.472122f, 0.449888f, 0.586202f, 0.447435f, 0.611846f, 0.434879f, 0.449905f, 0.430732f, 0.474834f, 0.321674f, 0.590495f, 0.626300f, 0.319127f, 0.606006f, 0.492763f, 0.445330f, 0.490219f, 0.319940f, 0.588298f, 0.643644f, 0.317760f, 0.596360f, 0.507993f, 0.440004f, 0.490555f, 0.378128f, 0.588227f, 0.604974f, 0.329202f, 0.561987f, 0.511572f, 0.403440f, 0.542761f, 0.331792f, 0.568397f, 0.583366f, 0.333122f, 0.608456f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.441356f, 0.431701f, 0.488343f, 0.311828f, 0.606159f, 0.632821f, 0.317863f, 0.629084f, 0.495613f, 0.441177f, 0.473223f, 0.335484f, 0.579139f, 0.646878f, 0.321269f, 0.595437f, 0.504999f, 0.443626f, 0.498154f, 0.369326f, 0.588410f, 0.600189f, 0.322347f, 0.562676f, 0.508419f, 0.405342f, 0.533092f, 0.335876f, 0.570568f, 0.589600f, 0.330741f, 0.609168f, 0.456943f, 0.365603f, 0.555030f, 0.454344f, 0.526263f, 0.519062f, 0.578652f, 0.425453f, 0.464039f, 0.391848f, 0.518985f, 0.419419f, 0.541410f, 0.514459f, 0.586459f, 0.470210f, 0.460338f, 0.408599f, 0.539512f, 0.446249f, 0.551945f, 0.511356f, 0.575513f, 0.424325f, 0.452212f, 0.418205f, 0.525148f, 0.459799f, 0.536327f, 0.541881f, 0.571451f, 0.452969f, 0.454154f, 0.354641f, 0.553889f, 0.451027f, 0.536270f, 0.521832f, 0.590756f, 0.429859f, 0.459101f, 0.394962f, 0.512076f, 0.419296f, 0.535702f, 0.516757f, 0.585606f, 0.478117f, 0.458365f, 0.422929f, 0.531943f, 0.447581f, 0.546387f, 0.511705f, 0.564350f, 0.425332f, 0.463274f, 0.429223f, 0.525922f, 0.452328f, 0.539095f, 0.534372f, 0.563738f, 0.449120f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.586343f, 0.566462f, 0.444339f, 0.481474f, 0.557556f, 0.495837f, 0.368487f, 0.425850f, 0.580159f, 0.565990f, 0.400882f, 0.462578f, 0.551037f, 0.497924f, 0.338502f, 0.468483f, 0.592753f, 0.536897f, 0.481975f, 0.489485f, 0.519290f, 0.509298f, 0.366838f, 0.461538f, 0.567139f, 0.559419f, 0.458050f, 0.468739f, 0.514875f, 0.512271f, 0.346335f, 0.449357f, 0.583058f, 0.557532f, 0.454426f, 0.492673f, 0.551748f, 0.496414f, 0.364023f, 0.430048f, 0.579431f, 0.565100f, 0.420761f, 0.466297f, 0.551315f, 0.487418f, 0.348148f, 0.461136f, 0.585687f, 0.535194f, 0.485465f, 0.488622f, 0.513327f, 0.508844f, 0.368049f, 0.455823f, 0.554855f, 0.560589f, 0.456398f, 0.477641f, 0.507017f, 0.518069f, 0.338229f, 0.444624f, 0.500594f, 0.616610f, 0.439949f, 0.495561f, 0.569213f, 0.540425f, 0.422667f, 0.627919f, 0.514283f, 0.584446f, 0.441141f, 0.528331f, 0.577047f, 0.508969f, 0.372295f, 0.646734f, 0.536256f, 0.591823f, 0.428652f, 0.485852f, 0.592863f, 0.525360f, 0.399985f, 0.623408f, 0.552463f, 0.606841f, 0.448560f, 0.466321f, 0.600628f, 0.566464f, 0.356481f, 0.551351f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.532692f, 0.601573f, 0.425963f, 0.477495f, 0.573122f, 0.544325f, 0.422438f, 0.629794f, 0.512145f, 0.593241f, 0.436187f, 0.532146f, 0.582008f, 0.499410f, 0.366728f, 0.631277f, 0.550263f, 0.590346f, 0.430967f, 0.477189f, 0.600022f, 0.528313f, 0.406504f, 0.603355f, 0.537075f, 0.605495f, 0.437735f, 0.474413f, 0.601068f, 0.542204f, 0.348555f, 0.581430f, 0.499619f, 0.480920f, 0.536032f, 0.413380f, 0.478027f, 0.524393f, 0.490201f, 0.530954f, 0.517442f, 0.475326f, 0.541763f, 0.366450f, 0.498398f, 0.509411f, 0.503732f, 0.490468f, 0.488084f, 0.505941f, 0.554614f, 0.371690f, 0.503635f, 0.510325f, 0.557424f, 0.564303f, 0.534730f, 0.536543f, 0.563296f, 0.362277f, 0.498957f, 0.508357f, 0.538003f, 0.554638f, 0.514150f, 0.481676f, 0.543535f, 0.414778f, 0.478296f, 0.529467f, 0.496600f, 0.522262f, 0.522734f, 0.480361f, 0.534209f, 0.379264f, 0.485836f, 0.500082f, 0.498644f, 0.501901f, 0.474729f, 0.503193f, 0.560206f, 0.362595f, 0.515144f, 0.512647f, 0.557224f, 0.567242f, 0.539217f, 0.533273f, 0.538641f, 0.373064f, 0.495733f, 0.499786f, 0.532998f, 0.547731f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f}; // {2, 3, 18, 8} std::vector present_key = {0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; // {2, 3, 18, 8} @@ -1116,7 +1149,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; // {2, 3, 6, 4} std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 1, 4, 13} + // {2, 1, 4, 18} std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, -0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, -0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f}; // {2, 3, 12, 4} std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; @@ -1132,7 +1165,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { // {2, 3, 4, 4} std::vector y = {-0.393782f, -0.387694f, -0.381606f, -0.375519f, -0.397492f, -0.391304f, -0.385116f, -0.378928f, -0.397474f, -0.391207f, -0.384941f, -0.378674f, -0.394849f, -0.388519f, -0.382190f, -0.375860f, -0.226271f, -0.220186f, -0.214101f, -0.208016f, -0.230042f, -0.223857f, -0.217672f, -0.211488f, -0.230104f, -0.223841f, -0.217577f, -0.211314f, -0.227525f, -0.221197f, -0.214870f, -0.208543f, -0.058757f, -0.052674f, -0.046592f, -0.040510f, -0.062587f, -0.056406f, -0.050224f, -0.044042f, -0.062730f, -0.056470f, -0.050209f, -0.043949f, -0.060198f, -0.053873f, -0.047548f, -0.041223f, 0.108760f, 0.114840f, 0.120919f, 0.126999f, 0.104873f, 0.111051f, 0.117229f, 0.123408f, 0.104648f, 0.110906f, 0.117163f, 0.123421f, 0.107131f, 0.113454f, 0.119777f, 0.126099f, 0.276279f, 0.282356f, 0.288433f, 0.294510f, 0.272337f, 0.278512f, 0.284687f, 0.290862f, 0.272031f, 0.278286f, 0.284540f, 0.290794f, 0.274463f, 0.280783f, 0.287104f, 0.293424f, 0.443800f, 0.449874f, 0.455949f, 0.462023f, 0.439807f, 0.445978f, 0.452150f, 0.458321f, 0.439418f, 0.445669f, 0.451921f, 0.458172f, 0.441797f, 0.448115f, 0.454433f, 0.460751f}; - // {2, 3, 13, 4} + // {2, 3, 12, 4} std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; // {2, 3, 18, 8} std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; @@ -1151,28 +1184,28 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { ); } -TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { - int batch_size = 2; // Q.shape[0] - int q_num_heads = 3; // Q.shape[1] - int q_sequence_length = 4; // Q.shape[2] - int head_size = 4; // Q.shape[3] - int kv_sequence_length = 6; // K.shape[2] and V.shape[2] - int kv_num_heads = 3; // K.shape[1] and V.shape[1] - int v_head_size = 4; // V.shape[3] - int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] +TEST(AttentionTest, TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] - // {2, 3, 4, 4} - std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; - // {2, 3, 6, 4} - std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 3, 6, 8} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; // {2, 3, 6, 4} - std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 3, 4, 13} - std::vector m = {-0.454545f, -0.451340f, -0.448135f, -0.444930f, -0.441725f, -0.438520f, -0.435315f, -0.432110f, -0.428904f, -0.425699f, -0.422494f, -0.419289f, -0.416084f, -0.412879f, -0.409674f, -0.406469f, -0.403263f, -0.400058f, -0.396853f, -0.393648f, -0.390443f, -0.387238f, -0.384033f, -0.380828f, -0.377622f, -0.374417f, -0.371212f, -0.368007f, -0.364802f, -0.361597f, -0.358392f, -0.355186f, -0.351981f, -0.348776f, -0.345571f, -0.342366f, -0.339161f, -0.335956f, -0.332751f, -0.329545f, -0.326340f, -0.323135f, -0.319930f, -0.316725f, -0.313520f, -0.310315f, -0.307110f, -0.303904f, -0.300699f, -0.297494f, -0.294289f, -0.291084f, -0.287879f, -0.284674f, -0.281469f, -0.278263f, -0.275058f, -0.271853f, -0.268648f, -0.265443f, -0.262238f, -0.259033f, -0.255828f, -0.252622f, -0.249417f, -0.246212f, -0.243007f, -0.239802f, -0.236597f, -0.233392f, -0.230186f, -0.226981f, -0.223776f, -0.220571f, -0.217366f, -0.214161f, -0.210956f, -0.207751f, -0.204545f, -0.201340f, -0.198135f, -0.194930f, -0.191725f, -0.188520f, -0.185315f, -0.182110f, -0.178904f, -0.175699f, -0.172494f, -0.169289f, -0.166084f, -0.162879f, -0.159674f, -0.156469f, -0.153263f, -0.150058f, -0.146853f, -0.143648f, -0.140443f, -0.137238f, -0.134033f, -0.130828f, -0.127622f, -0.124417f, -0.121212f, -0.118007f, -0.114802f, -0.111597f, -0.108392f, -0.105186f, -0.101981f, -0.098776f, -0.095571f, -0.092366f, -0.089161f, -0.085956f, -0.082751f, -0.079545f, -0.076340f, -0.073135f, -0.069930f, -0.066725f, -0.063520f, -0.060315f, -0.057110f, -0.053904f, -0.050699f, -0.047494f, -0.044289f, -0.041084f, -0.037879f, -0.034674f, -0.031469f, -0.028263f, -0.025058f, -0.021853f, -0.018648f, -0.015443f, -0.012238f, -0.009033f, -0.005828f, -0.002622f, 0.000583f, 0.003788f, 0.006993f, 0.010198f, 0.013403f, 0.016608f, 0.019814f, 0.023019f, 0.026224f, 0.029429f, 0.032634f, 0.035839f, 0.039044f, 0.042249f, 0.045455f, 0.048660f, 0.051865f, 0.055070f, 0.058275f, 0.061480f, 0.064685f, 0.067890f, 0.071096f, 0.074301f, 0.077506f, 0.080711f, 0.083916f, 0.087121f, 0.090326f, 0.093531f, 0.096737f, 0.099942f, 0.103147f, 0.106352f, 0.109557f, 0.112762f, 0.115967f, 0.119172f, 0.122378f, 0.125583f, 0.128788f, 0.131993f, 0.135198f, 0.138403f, 0.141608f, 0.144814f, 0.148019f, 0.151224f, 0.154429f, 0.157634f, 0.160839f, 0.164044f, 0.167249f, 0.170455f, 0.173660f, 0.176865f, 0.180070f, 0.183275f, 0.186480f, 0.189685f, 0.192890f, 0.196096f, 0.199301f, 0.202506f, 0.205711f, 0.208916f, 0.212121f, 0.215326f, 0.218531f, 0.221737f, 0.224942f, 0.228147f, 0.231352f, 0.234557f, 0.237762f, 0.240967f, 0.244172f, 0.247378f, 0.250583f, 0.253788f, 0.256993f, 0.260198f, 0.263403f, 0.266608f, 0.269814f, 0.273019f, 0.276224f, 0.279429f, 0.282634f, 0.285839f, 0.289044f, 0.292249f, 0.295455f, 0.298660f, 0.301865f, 0.305070f, 0.308275f, 0.311480f, 0.314685f, 0.317890f, 0.321096f, 0.324301f, 0.327506f, 0.330711f, 0.333916f, 0.337121f, 0.340326f, 0.343531f, 0.346737f, 0.349942f, 0.353147f, 0.356352f, 0.359557f, 0.362762f, 0.365967f, 0.369172f, 0.372378f, 0.375583f, 0.378788f, 0.381993f, 0.385198f, 0.388403f, 0.391608f, 0.394814f, 0.398019f, 0.401224f, 0.404429f, 0.407634f, 0.410839f, 0.414044f, 0.417249f, 0.420455f, 0.423660f, 0.426865f, 0.430070f, 0.433275f, 0.436480f, 0.439685f, 0.442890f, 0.446096f, 0.449301f, 0.452506f, 0.455711f, 0.458916f, 0.462121f, 0.465326f, 0.468531f, 0.471737f, 0.474942f, 0.478147f, 0.481352f, 0.484557f, 0.487762f, 0.490967f, 0.494172f, 0.497378f, 0.500583f, 0.503788f, 0.506993f, 0.510198f, 0.513403f, 0.516608f, 0.519814f, 0.523019f, 0.526224f, 0.529429f, 0.532634f, 0.535839f, 0.539044f, 0.542249f}; - // {2, 3, 12, 4} - std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; - // {2, 3, 12, 4} - std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f}; + // {2, 3, 12, 8} + std::vector past_key = {0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f}; + // {2, 3, 12, 8} + std::vector past_value = {0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f}; ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); @@ -1181,14 +1214,15 @@ TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); - // {2, 3, 4, 4} - std::vector y = {-0.385742f, -0.379327f, -0.372911f, -0.366496f, -0.385554f, -0.379139f, -0.372723f, -0.366308f, -0.385366f, -0.378950f, -0.372535f, -0.366119f, -0.385178f, -0.378762f, -0.372347f, -0.365931f, -0.218323f, -0.211907f, -0.205492f, -0.199076f, -0.218134f, -0.211719f, -0.205304f, -0.198888f, -0.217946f, -0.211531f, -0.205115f, -0.198700f, -0.217758f, -0.211342f, -0.204927f, -0.198512f, -0.050903f, -0.044487f, -0.038072f, -0.031657f, -0.050715f, -0.044299f, -0.037884f, -0.031468f, -0.050526f, -0.044111f, -0.037695f, -0.031280f, -0.050338f, -0.043922f, -0.037507f, -0.031092f, 0.116517f, 0.122932f, 0.129348f, 0.135763f, 0.116705f, 0.123121f, 0.129536f, 0.135952f, 0.116894f, 0.123309f, 0.129724f, 0.136140f, 0.117082f, 0.123497f, 0.129913f, 0.136328f, 0.283937f, 0.290352f, 0.296768f, 0.303183f, 0.284125f, 0.290540f, 0.296956f, 0.303371f, 0.284313f, 0.290729f, 0.297144f, 0.303559f, 0.284501f, 0.290917f, 0.297332f, 0.303747f, 0.451356f, 0.457772f, 0.464187f, 0.470602f, 0.451544f, 0.457960f, 0.464375f, 0.470790f, 0.451732f, 0.458148f, 0.464563f, 0.470978f, 0.451920f, 0.458336f, 0.464751f, 0.471166f}; - // {2, 3, 13, 4} - std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 8} + std::vector y = {0.431265f, 0.558994f, 0.492979f, 0.535281f, 0.609591f, 0.466737f, 0.692090f, 0.412591f, 0.468058f, 0.623595f, 0.468127f, 0.483497f, 0.577278f, 0.512802f, 0.639767f, 0.427679f, 0.422704f, 0.532822f, 0.449594f, 0.560548f, 0.608427f, 0.476187f, 0.695694f, 0.425740f, 0.447270f, 0.528366f, 0.506840f, 0.501836f, 0.547248f, 0.457381f, 0.583533f, 0.471707f, 0.414727f, 0.517263f, 0.342732f, 0.363543f, 0.677046f, 0.664675f, 0.271455f, 0.479982f, 0.438313f, 0.537211f, 0.342649f, 0.402609f, 0.660072f, 0.631518f, 0.266481f, 0.501402f, 0.458457f, 0.519536f, 0.434125f, 0.443849f, 0.614893f, 0.636419f, 0.310940f, 0.497030f, 0.433312f, 0.522457f, 0.417441f, 0.405432f, 0.617509f, 0.592985f, 0.310558f, 0.490073f, 0.499459f, 0.430465f, 0.601451f, 0.404111f, 0.502848f, 0.415186f, 0.440655f, 0.478187f, 0.536562f, 0.376663f, 0.527310f, 0.363608f, 0.443744f, 0.476396f, 0.453812f, 0.498910f, 0.483497f, 0.433209f, 0.541590f, 0.366029f, 0.513807f, 0.477506f, 0.492110f, 0.527910f, 0.471458f, 0.419741f, 0.536529f, 0.407806f, 0.512188f, 0.467064f, 0.496260f, 0.519270f, 0.683252f, 0.426643f, 0.425275f, 0.457410f, 0.611686f, 0.591234f, 0.394568f, 0.446171f, 0.637484f, 0.426481f, 0.346779f, 0.466867f, 0.585075f, 0.558250f, 0.387627f, 0.507636f, 0.658808f, 0.467355f, 0.496107f, 0.556756f, 0.513309f, 0.520842f, 0.411220f, 0.451704f, 0.661693f, 0.463543f, 0.421647f, 0.486068f, 0.552701f, 0.484705f, 0.412050f, 0.449818f, 0.637941f, 0.564086f, 0.543446f, 0.530844f, 0.627347f, 0.520370f, 0.389963f, 0.520054f, 0.574335f, 0.604007f, 0.468559f, 0.473710f, 0.559229f, 0.504183f, 0.453090f, 0.564618f, 0.568083f, 0.541180f, 0.491888f, 0.485970f, 0.564150f, 0.506989f, 0.421426f, 0.544228f, 0.616426f, 0.467555f, 0.529898f, 0.487372f, 0.574411f, 0.471969f, 0.388121f, 0.485012f, 0.533687f, 0.523210f, 0.560021f, 0.490233f, 0.443149f, 0.420163f, 0.538998f, 0.606965f, 0.586616f, 0.478324f, 0.572142f, 0.517933f, 0.441955f, 0.411890f, 0.550505f, 0.604577f, 0.541173f, 0.473423f, 0.505749f, 0.473388f, 0.389025f, 0.498730f, 0.507861f, 0.584389f, 0.519963f, 0.461030f, 0.576878f, 0.471281f, 0.461238f, 0.496673f, 0.509573f, 0.568405f}; // {2, 3, 18, 8} - std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 3, 4, 13} - std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + std::vector present_key = {0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 8} + std::vector present_value = {0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + constexpr float inff = std::numeric_limits::infinity(); + std::vector qk_matmul = {2.137658f, 1.567682f, 1.582827f, 0.953936f, 0.636597f, 1.001645f, 1.885707f, 1.361086f, 1.495408f, 1.566455f, 1.459078f, 1.668413f, 0.904174f, -inff, -inff, -inff, -inff, -inff, 1.229267f, 0.591855f, 1.372683f, 0.964445f, 1.006092f, 1.046331f, 1.712052f, 1.060710f, 2.141520f, 1.917742f, 1.063752f, 0.892409f, 0.884336f, 0.881352f, -inff, -inff, -inff, -inff, 2.235662f, 1.742821f, 2.198921f, 1.079357f, 1.510221f, 1.812315f, 1.396341f, 1.864746f, 1.498768f, 2.115730f, 0.844762f, 1.323617f, 1.096593f, 1.033003f, 1.868677f, -inff, -inff, -inff, 1.429269f, 0.876355f, 0.928405f, 1.469794f, 0.649940f, 1.435654f, 1.452830f, 1.053687f, 1.338220f, 0.966775f, 1.237266f, 1.488850f, 1.438267f, 0.931250f, 1.633272f, 0.944889f, -inff, -inff, 1.172613f, 1.105815f, 1.263303f, 1.702161f, 1.406517f, 1.808470f, 1.496128f, 1.169961f, 1.428707f, 1.393064f, 1.624670f, 1.287919f, 0.674733f, -inff, -inff, -inff, -inff, -inff, 0.838456f, 1.191558f, 1.771291f, 1.491907f, 0.911088f, 0.865799f, 1.154893f, 1.472593f, 0.826140f, 0.896018f, 1.281853f, 0.942941f, 1.470656f, 0.816028f, -inff, -inff, -inff, -inff, 1.133820f, 1.086309f, 1.712385f, 1.254675f, 1.427773f, 0.748848f, 1.056134f, 1.187805f, 1.419181f, 1.140224f, 1.269629f, 1.135934f, 0.694738f, 1.528325f, 0.959286f, -inff, -inff, -inff, 1.160321f, 1.097000f, 1.485019f, 1.111147f, 0.836961f, 0.948765f, 1.234762f, 0.835082f, 0.833382f, 0.589928f, 1.266538f, 1.303439f, 0.622733f, 0.837537f, 0.605730f, 0.730216f, -inff, -inff, 2.078597f, 0.610472f, 1.371772f, 0.794857f, 1.018924f, 1.165257f, 1.466839f, 1.206415f, 1.662507f, 1.098436f, 1.283408f, 1.533854f, 1.247966f, -inff, -inff, -inff, -inff, -inff, 1.707491f, 0.439978f, 0.919238f, 0.297115f, 0.982817f, 1.370520f, 0.766707f, 0.938981f, 1.095468f, 1.442393f, 0.742909f, 0.529869f, 0.628822f, 1.353301f, -inff, -inff, -inff, -inff, 1.483284f, 1.334536f, 0.757364f, 1.243801f, 0.767143f, 0.919318f, 0.693929f, 1.000990f, 1.107699f, 1.001247f, 1.434079f, 1.522769f, 0.696104f, 1.336034f, 0.501240f, -inff, -inff, -inff, 1.535892f, 1.342303f, 0.701559f, 1.211220f, 1.510985f, 0.961962f, 1.471503f, 1.440467f, 1.835586f, 0.947043f, 1.254547f, 1.009386f, 0.842613f, 1.508191f, 1.233544f, 1.280385f, -inff, -inff, 1.552432f, 0.958768f, 1.676495f, 1.810273f, 1.019336f, 1.487615f, 0.695035f, 1.391893f, 1.060641f, 0.917107f, 1.115109f, 1.128137f, 0.986429f, -inff, -inff, -inff, -inff, -inff, 1.289288f, 1.303667f, 0.882238f, 1.948027f, 1.580638f, 0.863439f, 1.059965f, 2.095325f, 1.493638f, 0.654104f, 0.828719f, 1.673449f, 0.479778f, 1.149678f, -inff, -inff, -inff, -inff, 1.177682f, 1.225590f, 1.735621f, 2.114078f, 1.905758f, 1.835981f, 1.432170f, 1.444457f, 2.016032f, 0.762211f, 1.059737f, 1.378216f, 1.564930f, 1.950097f, 1.598798f, -inff, -inff, -inff, 0.820477f, 0.962096f, 1.188223f, 1.264395f, 1.676953f, 1.487113f, 0.962162f, 1.377522f, 1.370079f, 1.450785f, 1.131087f, 1.962317f, 0.764849f, 0.777860f, 1.194763f, 1.030136f, -inff, -inff, 1.096708f, 1.345589f, 1.404595f, 1.370459f, 1.263369f, 1.364863f, 0.489623f, 0.596189f, 1.079480f, 0.915348f, 0.770954f, 1.548047f, 1.519504f, -inff, -inff, -inff, -inff, -inff, 1.856943f, 0.790590f, 1.235241f, 2.061177f, 1.282346f, 1.896653f, 1.112410f, 1.622862f, 0.780625f, 1.990919f, 1.693934f, 1.466544f, 1.026297f, 1.323339f, -inff, -inff, -inff, -inff, 1.778816f, 1.746915f, 1.169870f, 1.847628f, 0.729303f, 2.421048f, 1.266061f, 1.481203f, 1.016384f, 2.038725f, 1.132054f, 1.669076f, 1.958931f, 1.654780f, 1.644111f, -inff, -inff, -inff, 0.856287f, 1.124803f, 1.216201f, 0.831110f, 0.761234f, 1.204141f, 0.994307f, 0.832859f, 1.294077f, 1.566637f, 1.102631f, 1.472731f, 1.569911f, 0.779225f, 1.536189f, 1.277889f, -inff, -inff, 0.944230f, 1.585174f, 1.001532f, 0.973579f, 1.652668f, 1.112330f, 1.052878f, 1.326390f, 1.526319f, 1.790060f, 1.219317f, 1.742865f, 0.871467f, -inff, -inff, -inff, -inff, -inff, 0.794245f, 1.084904f, 0.813691f, 1.037344f, 0.254175f, 1.071614f, 0.477497f, 0.773591f, 1.317670f, 1.382451f, 0.759806f, 1.228428f, 0.583565f, 1.274037f, -inff, -inff, -inff, -inff, 0.865060f, 0.697643f, 1.300273f, 1.064195f, 1.435744f, 1.516307f, 0.626589f, 1.255387f, 1.115037f, 1.202643f, 1.789729f, 1.328769f, 1.046150f, 1.149905f, 1.696396f, -inff, -inff, -inff, 1.421552f, 1.324626f, 1.029005f, 0.960238f, 1.215132f, 1.450928f, 1.351898f, 1.718175f, 1.502146f, 1.736591f, 1.019685f, 1.130950f, 1.097223f, 1.330517f, 1.675029f, 1.069868f, -inff, -inff}; ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); @@ -1196,7 +1230,7 @@ TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, q, k, v, m, std::initializer_list(), past_key, past_value, - -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + 1, 1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, false, true, true // disable_cpu, disable_cuda, disable_dml ); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc b/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc index ea6b3f148979f..b0ee078335308 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc @@ -7,14 +7,12 @@ #include "core/framework/utils.h" #include "core/providers/cpu/nn/conv_attributes.h" #include "core/providers/utils.h" +#include "test/providers/internal_testing/internal_testing_execution_provider.h" namespace onnxruntime { namespace internal_testing_ep { -// can't use 'utils::kInternalTestingExecutionProvider' in the macro so redefine here to a name without '::' -constexpr const char* internal_testing_ep = utils::kInternalTestingExecutionProvider; - -ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, internal_testing_ep, +ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kInternalTestingExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Conv); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 390329c5cae7a..934916ea862e9 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -26,17 +26,15 @@ namespace internal_testing_ep { // NHWC Conv requires contrib ops #if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) -// the 'utils::' breaks the kernel registration macros -constexpr const char* internal_testing_ep = utils::kInternalTestingExecutionProvider; -class ONNX_OPERATOR_KERNEL_CLASS_NAME(internal_testing_ep, kMSInternalNHWCDomain, 11, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kInternalTestingExecutionProvider, kMSInternalNHWCDomain, 11, Conv); // register static kernels we have implementations for static std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); ORT_THROW_IF_ERROR(kernel_registry->Register( - BuildKernelCreateInfo())); return kernel_registry; @@ -68,7 +66,7 @@ void RegisterDummyStaticKernel(KernelRegistry& registry, const Node& node) { builder.SetName(node.OpType()) .SetDomain(node.Domain()) .SinceVersion(node.SinceVersion()) - .Provider(internal_testing_ep); + .Provider(kInternalTestingExecutionProvider); ORT_THROW_IF_ERROR(registry.Register(builder, DummyCreateKernel)); } @@ -85,7 +83,7 @@ constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP"; InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set& ops, const std::unordered_set& stop_ops, DataLayout preferred_layout) - : IExecutionProvider{utils::kInternalTestingExecutionProvider}, + : IExecutionProvider{kInternalTestingExecutionProvider}, ep_name_{INTERNAL_TESTING_EP}, ops_{ops}, stop_ops_{stop_ops}, @@ -221,7 +219,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& auto compile_capabilities = utils::CreateSupportedPartitions(graph_viewer, supported_compiled_nodes, stop_ops_, generate_metadef_name, ep_name_, - onnxruntime::utils::kInternalTestingExecutionProvider, + kInternalTestingExecutionProvider, /*QDQ NodeUnit map*/ nullptr, debug_output_); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 0caa0febc2796..8832265798798 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -9,6 +9,9 @@ namespace onnxruntime { namespace internal_testing_ep { +// Provider type of `InternalTestingExecutionProvider`, an EP used for internal testing. +constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider"; + class InternalTestingExecutionProvider : public IExecutionProvider { public: InternalTestingExecutionProvider(const std::unordered_set& ops, diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc index d58db5178032d..c085d1acd10c0 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc @@ -57,7 +57,7 @@ auto RunTest(const std::string& op, const ORTCHAR_T* model_path) { for (const auto& node : graph.Nodes()) { EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); ++num_partitions; @@ -116,7 +116,7 @@ TEST(InternalTestingEP, TestDependenciesCorrectlyHandled) { for (const auto& node : graph.Nodes()) { EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); ++num_partitions; @@ -227,7 +227,7 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin std::string unsupported_op_str; for (const Node& node : graph.Nodes()) { - if (node.GetExecutionProviderType() != utils::kInternalTestingExecutionProvider && + if (node.GetExecutionProviderType() != kInternalTestingExecutionProvider && ops.count(node.OpType()) == 0) { auto entry = unsupported_ops.find(node.OpType()); if (entry != unsupported_ops.end()) { @@ -288,12 +288,12 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; } else { - EXPECT_NE(node.GetExecutionProviderType(), utils::kInternalTestingExecutionProvider) + EXPECT_NE(node.GetExecutionProviderType(), kInternalTestingExecutionProvider) << "Node is downstream from a 'stop at' node and should not have been taken. Node:" << node.Name(); } - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); ++stats.num_compiled_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 275f29fdd9073..94e60739c3ccf 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -334,7 +334,7 @@ TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) { for (const auto& node : graph.Nodes()) { EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); EXPECT_THAT(node.OpType(), ::testing::StartsWith(expected_op_type_prefix)); @@ -353,7 +353,7 @@ static int CountAndValidateAssignedNodes(const Graph& current_graph, for (const auto& node : current_graph.Nodes()) { EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; - if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) { const NodeComputeInfo* compute_func = nullptr; EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); EXPECT_NE(compute_func, nullptr); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 1820664e1d604..b85030b46e94d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -689,15 +689,30 @@ def test_run_model_with_optional_sequence_input(self): def test_run_model(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - input_name = sess.get_inputs()[0].name - self.assertEqual(input_name, "X") - input_shape = sess.get_inputs()[0].shape - self.assertEqual(input_shape, [3, 2]) - output_name = sess.get_outputs()[0].name - self.assertEqual(output_name, "Y") - output_shape = sess.get_outputs()[0].shape - self.assertEqual(output_shape, [3, 2]) - res = sess.run([output_name], {input_name: x}) + + inputs = sess.get_inputs() + self.assertEqual(len(inputs), 1) + self.assertEqual(inputs[0].name, "X") + self.assertEqual(inputs[0].shape, [3, 2]) + + input_meminfos = sess.get_input_memory_infos() + self.assertEqual(len(input_meminfos), 1) + self.assertIsNotNone(input_meminfos[0]) + + input_epdevices = sess.get_input_epdevices() + # The entry my be None (null) but it should be present + self.assertEqual(len(input_epdevices), 1) + + outputs = sess.get_outputs() + self.assertEqual(len(outputs), 1) + self.assertEqual(outputs[0].name, "Y") + self.assertEqual(outputs[0].shape, [3, 2]) + + output_meminfos = sess.get_output_memory_infos() + self.assertEqual(len(output_meminfos), 1) + self.assertIsNotNone(output_meminfos[0]) + + res = sess.run([outputs[0].name], {inputs[0].name: x}) output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) @@ -1584,6 +1599,44 @@ def test_run_model_with_cuda_copy_stream(self): for _iteration in range(100000): session.run(output_names=["output"], input_feed={"shape": shape}) + def test_ort_device(self): + cpu_device = onnxrt.OrtDevice.make("cpu", 0) + self.assertEqual(cpu_device.device_id(), 0) + self.assertEqual(cpu_device.device_type(), 0) + self.assertEqual(cpu_device.device_vendor_id(), 0) + self.assertEqual(cpu_device.device_mem_type(), 0) + + def test_ort_memory_info(self): + cpu_memory_info = onnxrt.OrtMemoryInfo( + "Cpu", + onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, + 0, + onnxrt.OrtMemType.DEFAULT, + ) + self.assertEqual(cpu_memory_info.name, "Cpu") + self.assertEqual(cpu_memory_info.device_id, 0) + self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT) + self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR) + self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertEqual(cpu_memory_info.device_vendor_id, 0) + + def test_ort_memory_info_create_v2(self): + cpu_memory_info = onnxrt.OrtMemoryInfo.create_v2( + "Test", + onnxrt.OrtMemoryInfoDeviceType.CPU, + 0, # vendor_id + 0, # device_id + onnxrt.OrtDeviceMemoryType.DEFAULT, + 128, # alignment + onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, + ) + self.assertEqual(cpu_memory_info.name, "Test") + self.assertEqual(cpu_memory_info.device_id, 0) + self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT) + self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR) + self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertEqual(cpu_memory_info.device_vendor_id, 0) + def test_shared_allocator_using_create_and_register_allocator(self): # Create and register an arena based allocator diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index cb31627a87c48..d6281d165c053 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -226,6 +226,14 @@ def test_example_plugin_ep_devices(self): hw_metadata = hw_device.metadata self.assertGreater(len(hw_metadata), 0) # Should have at least SPDRP_HARDWAREID on Windows + test_mem_info = test_ep_device.memory_info(onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertIsNotNone(test_mem_info) + del test_mem_info + + test_sync_stream = test_ep_device.create_sync_stream() + self.assertIsNotNone(test_sync_stream) + del test_sync_stream + # Add EP plugin's OrtEpDevice to the SessionOptions. sess_options = onnxrt.SessionOptions() sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) @@ -282,6 +290,55 @@ def test_example_plugin_ep_data_transfer(self): self.unregister_execution_provider_library(ep_name) + def test_copy_tensors(self): + """ + Test global api copy_tensors between OrtValue objects + using EP plug-in data transfer + """ + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + ep_lib_path = "example_plugin_ep.dll" + try: + ep_lib_path = get_name("example_plugin_ep.dll") + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + # Generate 2 numpy arrays + a = np.random.rand(3, 2).astype(np.float32) + b = np.random.rand(3, 2).astype(np.float32) + + # Create OrtValue from numpy arrays on EP device + # the example EP pretends to use GPU memory, so we place it there + a_device = onnxrt.OrtValue.ortvalue_from_numpy(a, "gpu", 0, 0xBE57) + b_device = onnxrt.OrtValue.ortvalue_from_numpy(b, "gpu", 0, 0xBE57) + + # Create destination ort values with the same shape on CPU + a_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(a.shape, a.dtype) + b_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(b.shape, b.dtype) + + # source list + src_list = [a_device, b_device] + dst_list = [a_cpu_copy, b_cpu_copy] + # Passing None for stream as we copy between CPU + # Test None because it is allowed + onnxrt.copy_tensors(src_list, dst_list, None) + + # Release the OrtValue on the EP device + # before the EP library is unregistered + del src_list + del a_device + del b_device + + # Verify the contents + np.testing.assert_array_equal(a, a_cpu_copy.numpy()) + np.testing.assert_array_equal(b, b_cpu_copy.numpy()) + + self.unregister_execution_provider_library(ep_name) + if __name__ == "__main__": unittest.main(verbosity=1) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index b7a9da8e1b658..5199730ae323d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -494,6 +494,35 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, CApiTestWithProvider, ::testing::Values(0, 1, 2, 3, 4)); +TEST(CApiTest, DISABLED_TestInputPassThroughToOutput) { + const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx"); + Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_EQ(1U, inputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_EQ(1U, inputs_epdevices.size()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(7U, outputs_meminfos.size()); +} + +TEST(CApiTest, DISABLED_TestDanglingInput) { + // Here we test an issue with segments_ids that is an input not consumed by anything + // This kind of model is unlikely to be used in practice but we want to make sure it works + const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx"); + Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_EQ(2U, inputs_meminfos.size()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(2U, outputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_EQ(2U, inputs_epdevices.size()); + // One of the devices returning is null since the input is not consumed + // there is not a device for it. + const bool null_present = std::any_of(inputs_epdevices.begin(), inputs_epdevices.end(), + [](const auto& device) { return device == nullptr; }); + ASSERT_TRUE(null_present); +} + #if !defined(DISABLE_SPARSE_TENSORS) TEST(CApiTest, SparseOutputModel) { std::vector dense_shape{3, 3}; @@ -505,7 +534,15 @@ TEST(CApiTest, SparseOutputModel) { std::vector ort_inputs; std::vector input_names; const char* const output_names[] = {"values"}; + // This model produces a sparse output from a constant sparse initializer Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_TRUE(inputs_meminfos.empty()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(1U, outputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_TRUE(inputs_epdevices.empty()); + auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, 1); ASSERT_EQ(ort_outputs.size(), 1U); diff --git a/onnxruntime/test/testdata/add_mul_add.onnx b/onnxruntime/test/testdata/add_mul_add.onnx new file mode 100644 index 0000000000000..0e2bc1bb9cff9 --- /dev/null +++ b/onnxruntime/test/testdata/add_mul_add.onnx @@ -0,0 +1,28 @@ + +:´ + +A +B +add_outputadd_0"Add +' + +add_output +B +mul_outputmul_0"Mul + + +mul_output +ACadd_1"Add +Main_graphZ +A +  + +Z +B +  + +b +C +  + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/add_mul_add.py b/onnxruntime/test/testdata/add_mul_add.py new file mode 100644 index 0000000000000..c22a176e065dd --- /dev/null +++ b/onnxruntime/test/testdata/add_mul_add.py @@ -0,0 +1,37 @@ +from onnx import TensorProto, checker, helper, save + +# (A + B) * B + A +graph_proto = helper.make_graph( + nodes=[ + helper.make_node( + "Add", + inputs=["A", "B"], + outputs=["add_output"], + name="add_0", + ), + helper.make_node( + "Mul", + inputs=["add_output", "B"], + outputs=["mul_output"], + name="mul_0", + ), + helper.make_node( + "Add", + inputs=["mul_output", "A"], + outputs=["C"], + name="add_1", + ), + ], + name="Main_graph", + inputs=[ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 2]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]), + ], + outputs=[ + helper.make_tensor_value_info("C", TensorProto.FLOAT, [3, 2]), + ], +) + +model = helper.make_model(graph_proto) +checker.check_model(model, True) +save(model, "add_mul_add.onnx") diff --git a/onnxruntime/test/testdata/input_propagated_to_output.onnx b/onnxruntime/test/testdata/input_propagated_to_output.onnx new file mode 100644 index 0000000000000..feeab10556cb0 Binary files /dev/null and b/onnxruntime/test/testdata/input_propagated_to_output.onnx differ diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 43f6e480672ba..a04aafecbc81a 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -32,6 +32,31 @@ "^test_adagrad", "^test_adagrad_multiple", "^test_attention_3d.*", // wrong expected values in onnx==1.18.0, fixed in 1.19.0 + "^test_attention_4d_diff_heads_mask4d_padded_kv*", // pending onnx update + "^test_attention_3d_gqa*", // pending onnx update + "^test_attention_3d_gqa_causal", // pending onnx update + "^test_attention_3d_gqa_scaled", // pending onnx update + "^test_attention_3d_gqa_softcap", // pending onnx update + "^test_attention_3d_gqa_with_past_and_present", // pending onnx update + "^test_attention_4d_gqa*", // pending onnx update + "^test_attention_4d_gqa_causal", // pending onnx update + "^test_attention_4d_gqa_scaled", // pending onnx update + "^test_attention_4d_gqa_softcap", // pending onnx update + "^test_attention_4d_gqa_with_past_and_present", // pending onnx update + "^test_attention_*causal*", // pending onnx update + "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal*", // pending onnx update + "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal*", // pending onnx update + "^test_attention_4d_attn_mask_3d_causal_expanded*", // pending onnx update + "^test_attention_4d_fp16*", // precision issue: 1 / 192 mismatched elements + "^test_attention_4d_fp16_expanded*", // precision issue: 3 / 192 mismatched elements + "^test_l2normalization*", // LpNormalization(22) not implemented + "^test_l1normalization*", // LpNormalization(22) not implemented + "^test_lpnormalization*", // LpNormalization(22) not implemented + "^test_tensorscatter*", // TensorScatter(24) not implemented + "^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes + "^test_castlike_UINT4_to*", // ORT does not support ml_dtypes + "^test_castlike_INT4_to*", // ORT does not support ml_dtypes + "^test_cast_e8m0_*", // ORT does not support float8e8m0 "^test_batchnorm_epsilon_training_mode", "^test_batchnorm_example_training_mode", "^test_col2im_pads", // still one wrong value coming from the backtest example diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx new file mode 100644 index 0000000000000..a83c21030ad67 Binary files /dev/null and b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx differ diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py new file mode 100644 index 0000000000000..c5eb8a600d6b5 --- /dev/null +++ b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py @@ -0,0 +1,86 @@ +""" +Run this script to recreate the original onnx model. +Example usage: +python test_dangling_input_segment_ids.py out_model_path.onnx +""" + +import os +import sys + +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper + +DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids") + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +model = helper.make_model( + opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)], + ir_version=7, + graph=make_graph( + name="embed_layernorm_graph", + inputs=[ + helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]), + helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]), + ], + outputs=[ + helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]), + helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]), + ], + initializer=[ + numpy_helper.from_array( + np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]), + name="word_embed", + ), + numpy_helper.from_array( + np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]), + name="pos_embed", + ), + numpy_helper.from_array( + np.array( + [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], + dtype="float32", + ), + name="gamma", + ), + numpy_helper.from_array( + np.array( + [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32" + ), + name="beta", + ), + ], + nodes=[ + make_node( + "EmbedLayerNormalization", + inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"], + outputs=["layernorm_out", "mask_index_out"], + domain="com.microsoft", + ) + ], + ), +) + +if __name__ == "__main__" and len(sys.argv) == 2: + _, out_path = sys.argv + onnx.save(model, out_path)