diff --git a/VERSION_NUMBER b/VERSION_NUMBER
index a6c2798a482eb..49e0a31d4964d 100644
--- a/VERSION_NUMBER
+++ b/VERSION_NUMBER
@@ -1 +1 @@
-1.23.0
+1.23.1
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 6aad71e40b2a8..b23365e99c2d7 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1800,6 +1800,7 @@ endif()
if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
NOT onnxruntime_MINIMAL_BUILD)
+ # example_plugin_ep
file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
"${TEST_SRC_DIR}/autoep/library/*.cc")
onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
@@ -1822,6 +1823,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS
${ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG})
+ set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest")
+ source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src})
+
# test library
file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h"
"${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc")
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 660c63d056335..b8385a3352278 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -253,6 +253,8 @@ Do not modify directly.*
|||[9, 12]|**T** = tensor(float)|
|||[1, 8]|**T** = tensor(float)|
|MelWeightMatrix|*in* num_mel_bins:**T1**
*in* dft_length:**T1**
*in* sample_rate:**T1**
*in* lower_edge_hertz:**T2**
*in* upper_edge_hertz:**T2**
*out* output:**T3**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(float)
**T3** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[8, 11]|**T** = tensor(double), tensor(float)|
diff --git a/docs/python/README.rst b/docs/python/README.rst
index fdef200c1d0de..c23c194ed8132 100644
--- a/docs/python/README.rst
+++ b/docs/python/README.rst
@@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime >;
/** \brief Wrapper around ::OrtSyncStream
*
*/
-struct SyncStream : detail::Base {
- explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used
- explicit SyncStream(OrtSyncStream* p) : Base{p} {} ///< Take ownership of a pointer created by C API
- void* GetHandle() const; ///< Wraps SyncStream_GetHandle
+
+namespace detail {
+template
+struct SyncStreamImpl : Base {
+ using B = Base;
+ using B::B;
+ // For some reason this is not a const method on the stream
+ void* GetHandle(); ///< Wraps SyncStream_GetHandle
};
+} // namespace detail
+
+struct SyncStream : detail::SyncStreamImpl {
+ ///< Create an empty SyncStream object, must be assigned a valid one to be used
+ explicit SyncStream(std::nullptr_t) {}
+ ///< Take ownership of a pointer created by C API
+ explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl{p} {}
+};
+
+using UnownedSyncStream = detail::SyncStreamImpl>;
namespace detail {
template
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 59979189eed0f..cb6448ad12a81 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -669,9 +669,12 @@ inline void KeyValuePairs::Remove(const char* key) {
GetApi().RemoveKeyValuePair(this->p_, key);
}
-inline void* SyncStream::GetHandle() const {
+namespace detail {
+template
+inline void* SyncStreamImpl::GetHandle() {
return GetApi().SyncStream_GetHandle(this->p_);
}
+} // namespace detail
namespace detail {
template
@@ -1582,11 +1585,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForInputs(
auto num_inputs = GetInputCount();
std::vector mem_infos;
- mem_infos.resize(num_inputs);
+ if (num_inputs > 0) {
+ mem_infos.resize(num_inputs);
- ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
- reinterpret_cast(mem_infos.data()),
- num_inputs));
+ ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
+ reinterpret_cast(mem_infos.data()),
+ num_inputs));
+ }
return mem_infos;
}
@@ -1598,11 +1603,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs
auto num_outputs = GetOutputCount();
std::vector mem_infos;
- mem_infos.resize(num_outputs);
+ if (num_outputs > 0) {
+ mem_infos.resize(num_outputs);
- ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
- reinterpret_cast(mem_infos.data()),
- num_outputs));
+ ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
+ reinterpret_cast(mem_infos.data()),
+ num_outputs));
+ }
return mem_infos;
}
@@ -1631,12 +1638,12 @@ template
inline std::vector ConstSessionImpl::GetEpDeviceForInputs() const {
auto num_inputs = GetInputCount();
std::vector input_devices;
- input_devices.resize(num_inputs);
-
- ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
- reinterpret_cast(input_devices.data()),
- num_inputs));
-
+ if (num_inputs > 0) {
+ input_devices.resize(num_inputs);
+ ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
+ reinterpret_cast(input_devices.data()),
+ num_inputs));
+ }
return input_devices;
}
diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts
index 994eb6f4300c1..f00f4ec8dee50 100644
--- a/js/common/lib/version.ts
+++ b/js/common/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.23.0';
+export const version = '1.23.1';
diff --git a/js/common/package-lock.json b/js/common/package-lock.json
index 706f8b46a3ad4..9ef468a229788 100644
--- a/js/common/package-lock.json
+++ b/js/common/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-common",
- "version": "1.23.0",
+ "version": "1.23.1",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-common",
- "version": "1.23.0",
+ "version": "1.23.1",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.25.7"
diff --git a/js/common/package.json b/js/common/package.json
index a0eff9095e6d7..200aff42f8fca 100644
--- a/js/common/package.json
+++ b/js/common/package.json
@@ -2,7 +2,7 @@
"license": "MIT",
"type": "module",
"name": "onnxruntime-common",
- "version": "1.23.0",
+ "version": "1.23.1",
"repository": {
"url": "https://github.com/Microsoft/onnxruntime.git",
"type": "git"
diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts
index 994eb6f4300c1..f00f4ec8dee50 100644
--- a/js/node/lib/version.ts
+++ b/js/node/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.23.0';
+export const version = '1.23.1';
diff --git a/js/node/package-lock.json b/js/node/package-lock.json
index bd7e6cc1966c7..0a65eab39df70 100644
--- a/js/node/package-lock.json
+++ b/js/node/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-node",
- "version": "1.23.0",
+ "version": "1.23.1",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-node",
- "version": "1.23.0",
+ "version": "1.23.1",
"hasInstallScript": true,
"license": "MIT",
"os": [
@@ -30,7 +30,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.23.0",
+ "version": "1.23.1",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.25.7"
diff --git a/js/node/package.json b/js/node/package.json
index 5520a48aa124a..1f29f2354b0d7 100644
--- a/js/node/package.json
+++ b/js/node/package.json
@@ -11,7 +11,7 @@
6
]
},
- "version": "1.23.0",
+ "version": "1.23.1",
"dependencies": {
"adm-zip": "^0.5.16",
"global-agent": "^3.0.0",
diff --git a/js/node/script/install-metadata-versions.js b/js/node/script/install-metadata-versions.js
index 3147f90904e7a..23df0b7ac96ed 100644
--- a/js/node/script/install-metadata-versions.js
+++ b/js/node/script/install-metadata-versions.js
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-module.exports = { nuget: [{ feed: 'nuget', version: '1.23.0' }] };
+module.exports = { nuget: [{ feed: 'nuget', version: '1.23.1' }] };
diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts
index 994eb6f4300c1..f00f4ec8dee50 100644
--- a/js/react_native/lib/version.ts
+++ b/js/react_native/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.23.0';
+export const version = '1.23.1';
diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json
index ec2147b2cc4ba..f681b9166da98 100644
--- a/js/react_native/package-lock.json
+++ b/js/react_native/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-react-native",
- "version": "1.23.0",
+ "version": "1.23.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-react-native",
- "version": "1.23.0",
+ "version": "1.23.1",
"license": "MIT",
"dependencies": {
"buffer": "^6.0.3",
@@ -31,7 +31,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.23.0",
+ "version": "1.23.1",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.25.7"
diff --git a/js/react_native/package.json b/js/react_native/package.json
index 7a5ee35bdb25a..a88f5cf267aed 100644
--- a/js/react_native/package.json
+++ b/js/react_native/package.json
@@ -37,7 +37,7 @@
"registry": "https://registry.npmjs.org/"
},
"source": "lib/index",
- "version": "1.23.0",
+ "version": "1.23.1",
"main": "dist/commonjs/index",
"homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md",
"files": [
diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts
index 994eb6f4300c1..f00f4ec8dee50 100644
--- a/js/web/lib/version.ts
+++ b/js/web/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.23.0';
+export const version = '1.23.1';
diff --git a/js/web/package-lock.json b/js/web/package-lock.json
index eabb198e97177..74776abb25bd5 100644
--- a/js/web/package-lock.json
+++ b/js/web/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-web",
- "version": "1.23.0",
+ "version": "1.23.1",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-web",
- "version": "1.23.0",
+ "version": "1.23.1",
"license": "MIT",
"dependencies": {
"flatbuffers": "^25.1.24",
@@ -50,7 +50,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.23.0",
+ "version": "1.23.1",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.25.7"
diff --git a/js/web/package.json b/js/web/package.json
index 425aa88035424..db20202b4f24e 100644
--- a/js/web/package.json
+++ b/js/web/package.json
@@ -7,7 +7,7 @@
"type": "git"
},
"author": "fs-eire",
- "version": "1.23.0",
+ "version": "1.23.1",
"jsdelivr": "dist/ort.min.js",
"dependencies": {
"flatbuffers": "^25.1.24",
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 550502cf3bc48..ac25159802092 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -8,7 +8,7 @@
or the `Github project `_.
"""
-__version__ = "1.23.0"
+__version__ = "1.23.1"
__author__ = "Microsoft"
# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
@@ -31,14 +31,17 @@
OrtAllocatorType, # noqa: F401
OrtArenaCfg, # noqa: F401
OrtCompileApiFlags, # noqa: F401
+ OrtDeviceMemoryType, # noqa: F401
OrtEpDevice, # noqa: F401
OrtExecutionProviderDevicePolicy, # noqa: F401
OrtExternalInitializerInfo, # noqa: F401
OrtHardwareDevice, # noqa: F401
OrtHardwareDeviceType, # noqa: F401
OrtMemoryInfo, # noqa: F401
+ OrtMemoryInfoDeviceType, # noqa: F401
OrtMemType, # noqa: F401
OrtSparseFormat, # noqa: F401
+ OrtSyncStream, # noqa: F401
RunOptions, # noqa: F401
SessionIOBinding, # noqa: F401
SessionOptions, # noqa: F401
@@ -78,6 +81,7 @@
OrtDevice, # noqa: F401
OrtValue, # noqa: F401
SparseTensor, # noqa: F401
+ copy_tensors, # noqa: F401
)
# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 98cc2158eb0d0..01ba492eb166e 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -226,13 +226,22 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
for (auto& node : graph_.Nodes()) {
const KernelCreateInfo* kci = nullptr;
auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
- if (!status.IsOK() && saving_ort_format) {
- // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
- // in that case we assigned the node to that EP but do not compile it into a fused node.
- // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it.
- // we now revert to the CPU EP kernel as a fallback.
- // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
- // if that's not possible for some reason we can fallback to the CPU EP implementation.
+
+ // There are two cases where we allow fallback to CPU EP kernels:
+ //
+ // 1. if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
+ // in that case we assigned the node to that EP but do not compile it into a fused node.
+ // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it.
+ // we now revert to the CPU EP kernel as a fallback.
+ // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
+ // if that's not possible for some reason we can fallback to the CPU EP implementation.
+ //
+ // 2. If the node is a memcpy node.
+ // EPs may provide their own memcpy kernels. The CPU EP provides a generic version to fall back to if the EP does
+ // not provide one.
+ const bool allow_cpu_ep_kernel_fallback = saving_ort_format || utils::IsMemcpyNode(node);
+
+ if (!status.IsOK() && allow_cpu_ep_kernel_fallback) {
node.SetExecutionProviderType(kCpuExecutionProvider);
status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
}
diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc
index 2c0a51f0bfdbc..ca64c7c7cae89 100644
--- a/onnxruntime/core/framework/utils.cc
+++ b/onnxruntime/core/framework/utils.cc
@@ -46,22 +46,13 @@ void DestroyStrings(void* p_data, int64_t elements) {
ptr[i].~string();
}
-bool ProviderIsCpuBased(const std::string& provider_type) {
- return provider_type == onnxruntime::kCpuExecutionProvider ||
- provider_type == onnxruntime::kDnnlExecutionProvider ||
- provider_type == onnxruntime::kVitisAIExecutionProvider ||
- provider_type == onnxruntime::kOpenVINOExecutionProvider ||
- provider_type == onnxruntime::kNnapiExecutionProvider ||
- provider_type == onnxruntime::kVSINPUExecutionProvider ||
- provider_type == onnxruntime::kAclExecutionProvider ||
- provider_type == onnxruntime::kArmNNExecutionProvider ||
- provider_type == onnxruntime::kRknpuExecutionProvider ||
- provider_type == onnxruntime::kCoreMLExecutionProvider ||
- provider_type == onnxruntime::kSnpeExecutionProvider ||
- provider_type == onnxruntime::kQnnExecutionProvider ||
- provider_type == onnxruntime::kXnnpackExecutionProvider ||
- provider_type == onnxruntime::kAzureExecutionProvider ||
- provider_type == onnxruntime::utils::kInternalTestingExecutionProvider;
+bool ProviderIsCpuBased(const IExecutionProvider& provider) {
+ return provider.GetDevice().Type() == OrtDevice::CPU;
+}
+
+bool IsMemcpyNode(const Node& node) {
+ return node.Domain() == kOnnxDomain &&
+ (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost");
}
static common::Status AllocateHelper(const AllocatorPtr& allocator,
@@ -210,7 +201,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_providers) {
for (const auto& execution_provider : execution_providers) {
- if (!ProviderIsCpuBased(execution_provider->Type())) {
+ if (!ProviderIsCpuBased(*execution_provider)) {
return false;
}
}
diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h
index 6b5c404e26b7f..796a17b0406da 100644
--- a/onnxruntime/core/framework/utils.h
+++ b/onnxruntime/core/framework/utils.h
@@ -52,12 +52,10 @@ void DestroyStrings(void* p_data, int64_t elements);
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);
-// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want
-// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal.
-constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider";
-
// return true if the execution provider is CPU based (meaning no copies to device are required)
-bool ProviderIsCpuBased(const std::string& provider_type);
+bool ProviderIsCpuBased(const IExecutionProvider& provider);
+
+bool IsMemcpyNode(const Node& node);
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc
index cc7682b2b418d..9d49c16391f78 100644
--- a/onnxruntime/core/optimizer/transformer_memcpy.cc
+++ b/onnxruntime/core/optimizer/transformer_memcpy.cc
@@ -1,7 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "transformer_memcpy.h"
+#include "core/optimizer/transformer_memcpy.h"
+
#include "core/common/logging/logging.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/execution_providers.h"
@@ -12,18 +13,39 @@
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
+static ProviderTypeToProviderMap GetProvidersByType(
+ const InlinedVector>& providers) {
+ ProviderTypeToProviderMap providers_by_type{};
+ for (const auto provider : providers) {
+ providers_by_type.emplace(provider->Type(), provider);
+ }
+ return providers_by_type;
+}
+
+MemcpyTransformer::MemcpyTransformer(InlinedVector> providers,
+ const KernelRegistryManager& registry_manager)
+ : GraphTransformer("MemcpyTransformer"),
+ providers_(std::move(providers)),
+ providers_by_type_(GetProvidersByType(providers_)),
+ registry_manager_(std::cref(registry_manager)) {
+}
+
// implements MemCpy node insertion in graph transform
// note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer
class TransformerMemcpyImpl {
public:
- TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
- : graph_(graph), provider_(provider) {}
+ TransformerMemcpyImpl(onnxruntime::Graph& graph, const IExecutionProvider& provider,
+ const ProviderTypeToProviderMap& providers_by_type)
+ : graph_(graph), provider_(provider), providers_by_type_(providers_by_type) {
+ }
bool ModifyGraph(const KernelRegistryManager& schema_registries,
const logging::Logger& logger,
int& copy_node_counter);
private:
+ bool IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const;
+
void ProcessDefs(onnxruntime::Node& node,
const KernelRegistryManager& kernel_registries,
InitializedTensorSet& initializers_consumed,
@@ -31,7 +53,9 @@ class TransformerMemcpyImpl {
void BuildDefsMapping(const onnxruntime::NodeArg* arg,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger);
- void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
+ void AddCopyNode(onnxruntime::NodeArg* arg,
+ bool is_input,
+ const logging::Logger& logger);
bool ProcessInitializers(const KernelRegistryManager& kernel_registries,
const InitializedTensorSet& initializers_consumed,
const logging::Logger& logger);
@@ -55,7 +79,8 @@ class TransformerMemcpyImpl {
std::map> provider_output_nodes_;
onnxruntime::Graph& graph_;
- std::string provider_;
+ const IExecutionProvider& provider_;
+ const ProviderTypeToProviderMap& providers_by_type_;
};
/** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer.
@@ -73,17 +98,18 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st
// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
// and mainly provides the subgraph recursion functionality
-common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
- const logging::Logger& logger) const {
- for (auto& provider : provider_types_) {
- if (!utils::ProviderIsCpuBased(provider)) {
- TransformerMemcpyImpl copy_impl(graph, provider);
+Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
+ const logging::Logger& logger) const {
+ for (const auto provider : providers_) {
+ const auto& provider_type = provider->Type();
+ if (!utils::ProviderIsCpuBased(*provider)) {
+ TransformerMemcpyImpl copy_impl(graph, *provider, providers_by_type_);
int copy_node_counter = 0;
auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter);
- if (copy_node_counter > 0 && provider == kCudaExecutionProvider) {
+ if (copy_node_counter > 0 && provider_type == kCudaExecutionProvider) {
LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name()
- << " for " << provider
+ << " for " << provider_type
<< ". It might have negative impact on performance (including unable to run CUDA graph). "
<< "Set session_options.log_severity_level=1 to see the detail logs before this message.";
}
@@ -213,15 +239,42 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
return modified;
}
+static const IExecutionProvider* FindProviderByType(ProviderTypeToProviderMap providers_by_type,
+ std::string_view provider_type) {
+ const auto it = providers_by_type.find(provider_type);
+ if (it != providers_by_type.end()) {
+ return &*it->second;
+ }
+ return nullptr;
+}
+
+bool TransformerMemcpyImpl::IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const {
+ const auto& node_provider_type = node.GetExecutionProviderType();
+ const auto* node_provider = FindProviderByType(providers_by_type_, node_provider_type);
+ ORT_ENFORCE(node_provider != nullptr, "Unable to get provider associated with provider type ", node_provider_type);
+
+ // Same provider?
+ if (node_provider->Type() == provider_.Type()) {
+ return true;
+ }
+
+ const auto& node_provider_device = node_provider->GetDevice();
+ const auto& provider_device = provider_.GetDevice();
+
+ // Same provider device type and vendor?
+ if (node_provider_device.Type() == provider_device.Type() &&
+ node_provider_device.Vendor() == provider_device.Vendor()) {
+ return true;
+ }
+
+ return false;
+}
+
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
const KernelRegistryManager& kernel_registries,
InitializedTensorSet& initializers_consumed,
const logging::Logger& logger) {
- auto node_provider_type = node.GetExecutionProviderType();
- if ((node_provider_type == provider_) ||
- (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
- (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
- (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
+ if (IsNodeCompatibleWithProvider(node)) {
provider_nodes_.insert(&node);
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;
@@ -268,9 +321,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
else
provider_output_defs_.insert(arg);
}
- } else if (node_provider_type != kCudaExecutionProvider && node_provider_type != kTensorrtExecutionProvider &&
- node_provider_type != kCudaExecutionProvider && node_provider_type != kNvTensorRTRTXExecutionProvider &&
- node_provider_type != kRocmExecutionProvider && node_provider_type != kMIGraphXExecutionProvider) {
+ } else {
for (const auto* arg : node.InputDefs()) {
if (arg->Exists())
non_provider_input_defs_.insert(arg);
@@ -297,7 +348,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger) {
for (auto& it : graph_.Nodes()) {
- if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue;
+ if (utils::IsMemcpyNode(it)) continue;
auto input_it =
std::find(it.MutableInputDefs().begin(), it.MutableInputDefs().end(), const_cast(arg));
auto output_it =
@@ -309,10 +360,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
if (arg_input_index == -1 && arg_output_index == -1)
continue;
auto node_provider_type = it.GetExecutionProviderType();
- if ((node_provider_type == provider_) ||
- (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
- (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
- (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
+ if (IsNodeCompatibleWithProvider(it)) {
const KernelCreateInfo* kci = nullptr;
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci));
if (arg_input_index != -1) {
@@ -325,9 +373,11 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
}
}
-void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) {
+void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg,
+ bool is_input,
+ const logging::Logger& logger) {
// create unique name for new def
- std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_);
+ std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_.Type());
auto* new_arg = &graph_.GetOrCreateNodeArg(new_def_name, arg->TypeAsProto());
auto* src_arg = is_input ? arg : new_arg;
@@ -338,12 +388,14 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input
const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost";
LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name()
- << " for " << provider_;
+ << " for " << provider_.Type();
auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory",
std::vector{src_arg},
std::vector{dst_arg});
- new_node.SetExecutionProviderType(provider_);
+
+ new_node.SetExecutionProviderType(provider_.Type());
+
std::map map = {{arg, new_arg}};
auto it = provider_input_nodes_.find(arg);
if (it != provider_input_nodes_.end()) {
diff --git a/onnxruntime/core/optimizer/transformer_memcpy.h b/onnxruntime/core/optimizer/transformer_memcpy.h
index a2403d269f89b..f6b60a83fcf32 100644
--- a/onnxruntime/core/optimizer/transformer_memcpy.h
+++ b/onnxruntime/core/optimizer/transformer_memcpy.h
@@ -5,13 +5,19 @@
#include
+#include "gsl/gsl"
+
#include "core/common/common.h"
+#include "core/common/inlined_containers.h"
+#include "core/framework/execution_provider.h"
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
+using ProviderTypeToProviderMap = InlinedHashMap>;
+
/**
@Class MemcpyTransformer
@@ -19,13 +25,14 @@ Transformer that inserts nodes to copy memory between devices when needed.
*/
class MemcpyTransformer : public GraphTransformer {
public:
- MemcpyTransformer(const std::vector& provider_types, const KernelRegistryManager& registry_manager)
- : GraphTransformer("MemcpyTransformer"), provider_types_(provider_types), registry_manager_(std::cref(registry_manager)) {}
+ MemcpyTransformer(InlinedVector> providers,
+ const KernelRegistryManager& registry_manager);
private:
- common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+ Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
- const std::vector provider_types_;
+ const InlinedVector> providers_;
+ const ProviderTypeToProviderMap providers_by_type_;
std::reference_wrapper registry_manager_;
};
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 5eac0523d953a..1030e368a5fd6 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -4,6 +4,7 @@
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/framework/allocator_utils.h"
+#include "core/framework/memcpy.h"
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/int4.h"
@@ -27,6 +28,33 @@ struct KernelRegistryAndStatus {
} // namespace
namespace onnxruntime {
+
+// The MemcpyFromHost and MemcpyToHost kernels registered for the CPU EP are generic memcpy kernels.
+// Other EPs may provide their own memcpy kernels.
+// For a memcpy between host (CPU) and device of some other EP:
+// - If the EP provides the corresponding memcpy kernel, it will be used.
+// - Otherwise, one of these generic memcpy kernels will be used.
+
+ONNX_OPERATOR_KERNEL_EX(
+ MemcpyFromHost,
+ kOnnxDomain,
+ 1,
+ kCpuExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .InputMemoryType(OrtMemTypeCPUInput, 0)
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()),
+ Memcpy);
+
+ONNX_OPERATOR_KERNEL_EX(
+ MemcpyToHost,
+ kOnnxDomain,
+ 1,
+ kCpuExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .OutputMemoryType(OrtMemTypeCPUOutput, 0)
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()),
+ Memcpy);
+
CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kCpuExecutionProvider}, info_{info} {}
@@ -39,6 +67,8 @@ std::vector CPUExecutionProvider::CreatePreferredAllocators() {
}
// Forward declarations of op kernels
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10, Clip);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, Elu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, HardSigmoid);
@@ -1379,6 +1409,8 @@ KernelCreateInfo BuildKernelCreateInfo() {
Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc
index 4fc3cdf961d17..9d3736173dae2 100644
--- a/onnxruntime/core/providers/cpu/llm/attention.cc
+++ b/onnxruntime/core/providers/cpu/llm/attention.cc
@@ -47,14 +47,14 @@ void make_copy(MLFloat16* mask_data, const MLFloat16* mask
template <>
void make_copy(float* mask_data, const bool* mask_index, size_t size) {
for (size_t i = 0; i < size; ++i) {
- mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits::lowest();
+ mask_data[i] = mask_index[i] ? 0.0f : negative_infinity();
}
}
template <>
void make_copy(MLFloat16* mask_data, const bool* mask_index, size_t size) {
for (size_t i = 0; i < size; ++i) {
- mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits::lowest();
+ mask_data[i] = mask_index[i] ? MLFloat16(0.f) : negative_infinity();
}
}
@@ -236,7 +236,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
mask_data = static_cast(allocated_ptr);
for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
- mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits::lowest();
+ mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity();
}
}
delete_mask_data = true;
@@ -262,7 +262,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
for (int i = 0; i < n_iter; ++i) {
for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
- mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits::lowest();
+ mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity();
}
}
}
@@ -317,7 +317,8 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
}
// handling GQA
- std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads;
+ std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads;
+ std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki;
const T* k = K + k_input_chunk_length * ki;
if (nullptr != present_key) {
@@ -347,7 +348,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
alpha,
Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
parameters.head_size * parameters.q_num_heads, // lda
- transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k,
+ transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb
beta,
output,
@@ -555,7 +556,8 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu
// handling GQA
std::ptrdiff_t batch_i = i / num_heads;
std::ptrdiff_t head_i = i % num_heads;
- std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads;
+ std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads;
+ std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi;
const T* v = V + v_input_chunk_length * vi;
if (nullptr != present_value) {
@@ -579,15 +581,15 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu
// V is transposed but not QK. We use GemmEx with a different value for ldb.
math::GemmEx(CblasNoTrans,
CblasNoTrans,
- sequence_length, // M
- v_head_size, // N
- total_sequence_length, // K
- 1.f, // alpha
- attention_probs + attention_probs_offset, // QK
- total_sequence_length, // lda
- transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
- transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
- 0.f, // beta
+ sequence_length, // M
+ v_head_size, // N
+ total_sequence_length, // K
+ 1.f, // alpha
+ attention_probs + attention_probs_offset, // QK
+ total_sequence_length, // lda
+ transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
+ transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
+ 0.f, // beta
output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
v_head_size * num_heads, // ldc
nullptr);
diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h
index 78889e48afb29..4fad6914f933d 100644
--- a/onnxruntime/core/providers/cpu/llm/attention.h
+++ b/onnxruntime/core/providers/cpu/llm/attention.h
@@ -9,6 +9,16 @@
namespace onnxruntime {
+template
+inline T negative_infinity() {
+ return -std::numeric_limits::infinity();
+}
+
+template <>
+inline MLFloat16 negative_infinity() {
+ return MLFloat16(-std::numeric_limits::infinity());
+}
+
template
class AttentionBase : public OpKernel {
public:
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index 0ee18cc6799fc..62210d65848d1 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -1466,12 +1466,15 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra
fused_inputs.erase(it);
erased.insert(output);
}
- // Only when output is neither in input list nor erased list, add the output to output list
+ // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
- fused_outputs[output] = output_order++;
+
+ if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
+ fused_outputs[output] = output_order++;
+ }
}
}
}
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index b60f64db1734d..508d932459bf9 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2114,12 +2114,15 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph
fused_inputs.erase(it);
erased.insert(output);
}
- // Only when output is neither in input list nor erased list, add the output to output list
+ // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
- fused_outputs[output] = output_order++;
+
+ if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
+ fused_outputs[output] = output_order++;
+ }
}
}
}
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index c0900c5ad28a0..112fd84c5ed45 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1517,12 +1517,12 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
// Insert copy node/s.
{
- std::vector provider_types;
+ InlinedVector> providers;
for (auto& provider_ptr : execution_providers_) {
- provider_types.push_back(provider_ptr->Type());
+ providers.push_back(provider_ptr.get());
}
- MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager_};
+ MemcpyTransformer copy_transformer{std::move(providers), kernel_registry_manager_};
ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(copy_transformer, *session_logger_, graph));
}
@@ -3383,17 +3383,58 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType
for (const auto* def : def_list) {
InlinedVector node_info_vec;
+ Status status;
if (type == SessionInputOutputType::kOutput) {
- ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec));
+ status = session_state_->GetOutputNodeInfo(def->Name(), node_info_vec);
} else {
- ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec));
+ status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec);
}
- // all entries are for the same OrtDevice so use the first one.
- // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
- // from the session state and use its OrtMemoryInfo.
- auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
- memory_info.push_back(&allocator->Info());
+ if (!status.IsOK()) {
+ if (type == SessionInputOutputType::kInput) {
+ return status;
+ }
+
+ // Check first if this output is produced by an input that directly
+ // propagates to output with the same name.
+ status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec);
+ if (status.IsOK()) {
+ // all entries are for the same OrtDevice so use the first one.
+ // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
+ // from the session state and use its OrtMemoryInfo.
+ auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
+ memory_info.push_back(&allocator->Info());
+ } else {
+ // Check if this output is produced by a constant initializer
+ // Pick the MemoryInfo from the initializer's OrtValue
+ const auto& ort_value_map = session_state_->GetOrtValueNameIdxMap();
+
+ OrtValueIndex ort_value_index;
+ status = ort_value_map.GetIdx(def->Name(), ort_value_index);
+ if (!status.IsOK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "Failed to find node output or a constant initializer producing output: ",
+ def->Name(), ".");
+ }
+
+ const auto& idx_to_ort_value = session_state_->GetInitializedTensors();
+ auto it = idx_to_ort_value.find(ort_value_index);
+ if (it == idx_to_ort_value.end()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "Failed to find node output or a constant initializer producing output: ",
+ def->Name(), ".");
+ }
+ const auto& tensor = it->second.Get();
+ auto allocator = session_state_->GetAllocator(tensor.Location());
+ memory_info.push_back(&allocator->Info());
+ }
+ } else {
+ // all entries are for the same OrtDevice so use the first one.
+ // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
+ // from the session state and use its OrtMemoryInfo.
+ auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
+ memory_info.push_back(&allocator->Info());
+ }
}
return Status::OK();
@@ -3422,15 +3463,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector node_info_vec;
ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec));
-
- // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map
- // instead of doing a linear search each time.
- const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType();
- auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
- return entry->ep_name == ep_name;
- });
-
- ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
+ assert(!node_info_vec.empty());
+ // If we have an input that is not consumed by any node,
+ // including nodes in subgraphs, then we return nullptr.
+ const auto* p_node = node_info_vec.front().p_node;
+ if (p_node != nullptr) {
+ const auto ep_name = p_node->GetExecutionProviderType();
+ auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
+ return entry->ep_name == ep_name;
+ });
+ ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
+ } else {
+ ep_devices.push_back(nullptr);
+ }
}
return Status::OK();
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 21d09df5cc4db..edc0cb6d2bd0f 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -4257,7 +4257,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz
static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");
// So that nobody forgets to finish an API version, this check will serve as a reminder:
-static_assert(std::string_view(ORT_VERSION) == "1.23.0",
+static_assert(std::string_view(ORT_VERSION) == "1.23.1",
"ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly");
// 1. Update the hardcoded version string in above static_assert to silence it
// 2. If there were any APIs added to ort_api_1_to_23 above:
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 35abad5760c32..4c3313046457c 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -199,6 +199,18 @@ def get_modelmeta(self) -> onnxruntime.ModelMetadata:
"Return the metadata. See :class:`onnxruntime.ModelMetadata`."
return self._model_meta
+ def get_input_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]:
+ "Return the memory info for the inputs."
+ return self._input_meminfos
+
+ def get_output_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]:
+ "Return the memory info for the outputs."
+ return self._output_meminfos
+
+ def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]:
+ "Return the execution providers for the inputs."
+ return self._input_epdevices
+
def get_providers(self) -> Sequence[str]:
"Return list of registered execution providers."
return self._providers
@@ -576,6 +588,9 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
self._inputs_meta = self._sess.inputs_meta
self._outputs_meta = self._sess.outputs_meta
self._overridable_initializers = self._sess.overridable_initializers
+ self._input_meminfos = self._sess.input_meminfos
+ self._output_meminfos = self._sess.output_meminfos
+ self._input_epdevices = self._sess.input_epdevices
self._model_meta = self._sess.model_meta
self._providers = self._sess.get_providers()
self._provider_options = self._sess.get_provider_options()
@@ -589,6 +604,9 @@ def _reset_session(self, providers, provider_options) -> None:
self._inputs_meta = None
self._outputs_meta = None
self._overridable_initializers = None
+ self._input_meminfos = None
+ self._output_meminfos = None
+ self._input_epdevices = None
self._model_meta = None
self._providers = None
self._provider_options = None
@@ -1134,6 +1152,15 @@ def update_inplace(self, np_arr) -> None:
self._ortvalue.update_inplace(np_arr)
+def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None:
+ """
+ Copy tensor data from source OrtValue sequence to destination OrtValue sequence.
+ """
+ c_sources = [s._get_c_value() for s in src]
+ c_dsts = [d._get_c_value() for d in dst]
+ C.copy_tensors(c_sources, c_dsts, stream)
+
+
class OrtDevice:
"""
A data structure that exposes the underlying C++ OrtDevice
@@ -1146,6 +1173,7 @@ def __init__(self, c_ort_device):
if isinstance(c_ort_device, C.OrtDevice):
self._ort_device = c_ort_device
else:
+ # An end user won't hit this error
raise ValueError(
"`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`"
)
@@ -1188,6 +1216,9 @@ def device_type(self):
def device_vendor_id(self):
return self._ort_device.vendor_id()
+ def device_mem_type(self):
+ return self._ort_device.mem_type()
+
class SparseTensor:
"""
diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
index 1fe7ab0884f9c..d74663ddb63d7 100644
--- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -333,7 +333,7 @@ void addOrtValueMethods(pybind11::module& m) {
})
#endif
// Get a pointer to Tensor data
- .def("data_ptr", [](OrtValue* ml_value) -> int64_t {
+ .def("data_ptr", [](OrtValue* ml_value) -> uintptr_t {
// TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors
ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported");
@@ -344,7 +344,7 @@ void addOrtValueMethods(pybind11::module& m) {
}
// Should cover x86 and x64 platforms
- return reinterpret_cast(tensor->MutableDataRaw());
+ return reinterpret_cast(tensor->MutableDataRaw());
})
.def("device_name", [](const OrtValue* ort_value) -> std::string {
if (ort_value->IsTensor()) {
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index e370518b1fffb..c17acc9ffff3a 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -22,6 +22,7 @@
#include "core/framework/data_transfer_utils.h"
#include "core/framework/data_types_internal.h"
#include "core/framework/error_code_helper.h"
+#include "core/framework/plugin_ep_stream.h"
#include "core/framework/provider_options_utils.h"
#include "core/framework/random_seed.h"
#include "core/framework/sparse_tensor.h"
@@ -1584,6 +1585,18 @@ void addGlobalMethods(py::module& m) {
},
R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc");
+ m.def(
+ "copy_tensors",
+ [](const std::vector& src, const std::vector& dest, py::object& py_arg) {
+ const OrtEnv* ort_env = GetOrtEnv();
+ OrtSyncStream* stream = nullptr;
+ if (!py_arg.is_none()) {
+ stream = py_arg.cast();
+ }
+ Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size()));
+ },
+ R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc");
+
#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE)
m.def(
"get_available_openvino_device_ids", []() -> std::vector {
@@ -1785,6 +1798,16 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.value("CPU", OrtMemTypeCPU)
.value("DEFAULT", OrtMemTypeDefault);
+ py::enum_(m, "OrtMemoryInfoDeviceType")
+ .value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU)
+ .value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU)
+ .value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU)
+ .value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA);
+
+ py::enum_(m, "OrtDeviceMemoryType")
+ .value("DEFAULT", OrtDeviceMemoryType_DEFAULT)
+ .value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE);
+
py::class_ device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc");
device.def(py::init())
.def(py::init([](OrtDevice::DeviceType type,
@@ -1813,6 +1836,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
.def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
.def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc")
+ .def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc")
// generic device types that are typically used with a vendor id.
.def_static("cpu", []() { return OrtDevice::CPU; })
.def_static("gpu", []() { return OrtDevice::GPU; })
@@ -1863,36 +1887,55 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
},
R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc");
+ py::class_ py_sync_stream(m, "OrtSyncStream",
+ R"pbdoc(Represents a synchronization stream for model inference.)pbdoc");
+
py::class_ py_ep_device(m, "OrtEpDevice",
R"pbdoc(Represents a hardware device that an execution provider supports
for model inference.)pbdoc");
py_ep_device.def_property_readonly(
"ep_name",
- [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; },
+ [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; },
R"pbdoc(The execution provider's name.)pbdoc")
.def_property_readonly(
"ep_vendor",
- [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; },
+ [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; },
R"pbdoc(The execution provider's vendor name.)pbdoc")
.def_property_readonly(
"ep_metadata",
- [](OrtEpDevice* ep_device) -> std::map {
+ [](const OrtEpDevice* ep_device) -> std::map {
return ep_device->ep_metadata.Entries();
},
R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc")
.def_property_readonly(
"ep_options",
- [](OrtEpDevice* ep_device) -> std::map {
+ [](const OrtEpDevice* ep_device) -> std::map {
return ep_device->ep_options.Entries();
},
R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc")
.def_property_readonly(
"device",
- [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& {
+ [](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& {
return *ep_device->device;
},
R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc",
- py::return_value_policy::reference_internal);
+ py::return_value_policy::reference_internal)
+ .def(
+ "memory_info",
+ [](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* {
+ Ort::ConstEpDevice ep_dev(ep_device);
+ return static_cast(ep_dev.GetMemoryInfo(memory_type));
+ },
+ R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc",
+ py::return_value_policy::reference_internal)
+ .def(
+ "create_sync_stream",
+ [](const OrtEpDevice* ep_device) -> std::unique_ptr {
+ Ort::ConstEpDevice ep_dev(ep_device);
+ Ort::SyncStream stream = ep_dev.CreateSyncStream();
+ return std::unique_ptr(stream.release());
+ },
+ R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc");
py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg");
// Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option.
@@ -1938,25 +1981,28 @@ for model inference.)pbdoc");
.def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes);
py::class_ ort_memory_info_binding(m, "OrtMemoryInfo");
- ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
- if (strcmp(name, onnxruntime::CPU) == 0) {
- return std::make_unique(onnxruntime::CPU, type, OrtDevice(), mem_type);
- } else if (strcmp(name, onnxruntime::CUDA) == 0) {
- return std::make_unique(
- onnxruntime::CUDA, type,
- OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA,
- static_cast(id)),
- mem_type);
- } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) {
- return std::make_unique(
- onnxruntime::CUDA_PINNED, type,
- OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
- static_cast(id)),
- mem_type);
- } else {
- throw std::runtime_error("Specified device is not supported.");
- }
- }));
+ ort_memory_info_binding.def(
+ py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
+ Ort::MemoryInfo result(name, type, id, mem_type);
+ return std::unique_ptr(result.release());
+ }))
+ .def_static(
+ "create_v2",
+ [](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id,
+ int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) {
+ Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type);
+ return std::unique_ptr(result.release());
+ },
+ R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc")
+ .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc")
+ .def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc")
+ .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc")
+ .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc")
+ .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType {
+ auto mem_type = mem_info->device.MemType();
+ return (mem_type == OrtDevice::MemType::DEFAULT) ?
+ OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc")
+ .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); });
py::class_
sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
@@ -2653,6 +2699,33 @@ including arg name, arg type (contains both type and shape).)pbdoc")
auto res = sess->GetSessionHandle()->GetModelMetadata();
OrtPybindThrowIfError(res.first);
return *(res.second); }, py::return_value_policy::reference_internal)
+ .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list {
+ Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle()));
+ auto inputs_mem_info = session.GetMemoryInfoForInputs();
+ py::list result;
+ for (const auto& info : inputs_mem_info) {
+ const auto* p_info = static_cast(info);
+ result.append(py::cast(p_info, py::return_value_policy::reference));
+ }
+ return result; })
+ .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list {
+ Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle()));
+ auto outputs_mem_info = session.GetMemoryInfoForOutputs();
+ py::list result;
+ for (const auto& info : outputs_mem_info) {
+ const auto* p_info = static_cast(info);
+ result.append(py::cast(p_info, py::return_value_policy::reference));
+ }
+ return result; })
+ .def_property_readonly("input_epdevices", [](const PyInferenceSession* sess) -> py::list {
+ Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle()));
+ auto ep_devices = session.GetEpDeviceForInputs();
+ py::list result;
+ for (const auto& device : ep_devices) {
+ const auto* p_device = static_cast(device);
+ result.append(py::cast(p_device, py::return_value_policy::reference));
+ }
+ return result; })
.def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
Status status;
diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc
index e4265713d2d0a..5d8245618dcd6 100644
--- a/onnxruntime/test/autoep/library/ep.cc
+++ b/onnxruntime/test/autoep/library/ep.cc
@@ -233,7 +233,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
for (const auto& node : nodes) {
auto op_type = node.GetOperatorType();
- if (op_type != "Mul") {
+ if (op_type == "Mul") {
// Check that Mul has inputs/output of type float
std::vector inputs = node.GetInputs();
std::vector outputs = node.GetOutputs();
@@ -248,11 +248,36 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
continue; // Input or output is not of type float
}
+ {
+ const auto input_0_shape = GetTensorShape(inputs[0]),
+ input_1_shape = GetTensorShape(inputs[1]);
+
+ if (!input_0_shape.has_value() || !input_1_shape.has_value()) {
+ continue; // unable to get input shape
+ }
+
+ const auto is_static_shape = [](gsl::span shape) -> bool {
+ return std::all_of(shape.begin(), shape.end(), [](int64_t dim) { return dim >= 0; });
+ };
+
+ if (!is_static_shape(*input_0_shape) || !is_static_shape(*input_1_shape)) {
+ continue; // input shape has dynamic dimensions
+ }
+
+ if (*input_0_shape != *input_1_shape) {
+ continue; // input shapes do not match (no broadcasting support for now)
+ }
+ }
+
supported_nodes.push_back(node); // Only support a single Mul for now.
break;
}
}
+ if (supported_nodes.empty()) {
+ return nullptr;
+ }
+
// Create (optional) fusion options for the supported nodes to fuse.
OrtNodeFusionOptions node_fusion_options = {};
node_fusion_options.ort_version_supported = ORT_API_VERSION;
@@ -317,7 +342,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
Ort::ConstNode fused_node{fused_nodes[0]};
auto ep_name = fused_node.GetEpName();
- if (ep_name != "example_ep") {
+ if (ep_name != ep->name_) {
Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL);
return status.release();
}
diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/ep_stream_support.cc
index 1f6c16a8cb358..c648474d4fad7 100644
--- a/onnxruntime/test/autoep/library/ep_stream_support.cc
+++ b/onnxruntime/test/autoep/library/ep_stream_support.cc
@@ -61,7 +61,12 @@ OrtStatus* ORT_API_CALL NotificationImpl::ActivateImpl(_In_ OrtSyncNotificationI
/*static*/
OrtStatus* ORT_API_CALL NotificationImpl::WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr,
_In_ OrtSyncStream* stream) noexcept {
+ if (stream == nullptr) {
+ return nullptr;
+ }
+
auto& impl = *static_cast(this_ptr);
+
void* handle = impl.ort_api.SyncStream_GetHandle(stream);
static_cast(handle);
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc
index 263b4d208bd91..8b36f5f4e9a13 100644
--- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc
+++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc
@@ -35,3 +35,14 @@ void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) {
}
result = true;
}
+
+std::optional> GetTensorShape(Ort::ConstValueInfo value_info) {
+ const auto type_info = value_info.TypeInfo();
+ const auto onnx_type = type_info.GetONNXType();
+ if (onnx_type != ONNX_TYPE_TENSOR) {
+ return std::nullopt;
+ }
+
+ const auto type_shape = type_info.GetTensorTypeAndShapeInfo();
+ return type_shape.GetShape();
+}
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h
index e8c086d38a7cb..decc89251dc7b 100644
--- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h
+++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h
@@ -4,7 +4,9 @@
#pragma once
#include
+#include
#include
+#include
#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
@@ -108,3 +110,6 @@ OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessio
// Returns true (via output parameter) if the given OrtValueInfo represents a float tensor.
void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result);
+
+// Gets the tensor shape from `value_info`. Returns std::nullopt if `value_info` is not a tensor.
+std::optional> GetTensorShape(Ort::ConstValueInfo value_info);
diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc
index 0f4a654f116c4..78be22d082692 100644
--- a/onnxruntime/test/autoep/test_execution.cc
+++ b/onnxruntime/test/autoep/test_execution.cc
@@ -22,7 +22,8 @@ namespace onnxruntime {
namespace test {
namespace {
-void RunModelWithPluginEp(Ort::SessionOptions& session_options) {
+
+void RunMulModelWithPluginEp(const Ort::SessionOptions& session_options) {
Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options);
// Create input
@@ -47,6 +48,38 @@ void RunModelWithPluginEp(Ort::SessionOptions& session_options) {
gsl::span output_span(output_data, 6);
EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12));
}
+
+void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_options) {
+ // This model has Add -> Mul -> Add. The example plugin EP only supports Mul.
+ Ort::Session session(*ort_env, ORT_TSTR("testdata/add_mul_add.onnx"), session_options);
+
+ // Create inputs
+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+ std::vector shape = {3, 2};
+
+ std::vector a_data{1, 2, 3, 4, 5, 6};
+ std::vector b_data{2, 3, 4, 5, 6, 7};
+
+ std::vector ort_inputs{};
+ ort_inputs.emplace_back(
+ Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), shape.data(), shape.size()));
+ ort_inputs.emplace_back(
+ Ort::Value::CreateTensor(memory_info, b_data.data(), b_data.size(), shape.data(), shape.size()));
+
+ std::array ort_input_names{"A", "B"};
+
+ // Run session and get outputs
+ std::array output_names{"C"};
+ std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(),
+ ort_inputs.size(), output_names.data(), output_names.size());
+
+ // Check expected output values
+ Ort::Value& ort_output = ort_outputs[0];
+ const float* output_data = ort_output.GetTensorData();
+ gsl::span output_span(output_data, 6);
+ EXPECT_THAT(output_span, ::testing::ElementsAre(7, 17, 31, 49, 71, 97));
+}
+
} // namespace
// Creates a session with the example plugin EP and runs a model with a single Mul node.
@@ -61,7 +94,7 @@ TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) {
std::unordered_map ep_options;
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
- RunModelWithPluginEp(session_options);
+ RunMulModelWithPluginEp(session_options);
}
// Creates a session with the example plugin EP and runs a model with a single Mul node.
@@ -74,10 +107,23 @@ TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) {
// PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this.
Ort::SessionOptions session_options;
session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU);
- RunModelWithPluginEp(session_options);
+ RunMulModelWithPluginEp(session_options);
}
}
+TEST(OrtEpLibrary, PluginEp_AppendV2_PartiallySupportedModelInference) {
+ RegisteredEpDeviceUniquePtr example_ep;
+ Utils::RegisterAndGetExampleEp(*ort_env, example_ep);
+ Ort::ConstEpDevice plugin_ep_device(example_ep.get());
+
+ // Create session with example plugin EP
+ Ort::SessionOptions session_options;
+ std::unordered_map ep_options;
+ session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
+
+ RunPartiallySupportedModelWithPluginEp(session_options);
+}
+
// Generate an EPContext model with a plugin EP.
// This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs.
TEST(OrtEpLibrary, PluginEp_GenEpContextModel) {
@@ -98,6 +144,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) {
// Create model compilation options from the session options.
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+ compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED);
compile_options.SetInputModelPath(input_model_file);
compile_options.SetOutputModelPath(output_model_file);
diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc
index 6e86e5b58aead..01b253446974b 100644
--- a/onnxruntime/test/framework/memcpy_transformer_test.cc
+++ b/onnxruntime/test/framework/memcpy_transformer_test.cc
@@ -71,8 +71,18 @@ void ExpectCopy(const onnxruntime::Node& source, const std::string copy_op,
}
EXPECT_TRUE(false) << "Copy node expected but not found";
}
+
#ifdef USE_CUDA
+static InlinedVector> GetNotNullProviderPtrs(
+ const ExecutionProviders& providers) {
+ InlinedVector> not_null_provider_ptrs{};
+ for (auto& provider_ptr : providers) {
+ not_null_provider_ptrs.emplace_back(provider_ptr.get());
+ }
+ return not_null_provider_ptrs;
+}
+
TEST(TransformerTest, MemcpyTransformerTest) {
std::unordered_map domain_to_version;
domain_to_version[kOnnxDomain] = 7;
@@ -112,7 +122,11 @@ TEST(TransformerTest, MemcpyTransformerTest) {
KernelRegistryManager test_registry_manager;
ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
- MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+ InlinedVector> providers;
+ for (auto& provider_ptr : execution_providers) {
+ providers.push_back(provider_ptr.get());
+ }
+ MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
bool modified = false;
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
@@ -167,7 +181,7 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) {
KernelRegistryManager test_registry_manager;
ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
- MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+ MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
bool modified = false;
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
@@ -262,6 +276,8 @@ TEST(TransformerTest, TestInitializerDuplicationInSubgraph) {
if_node.AddAttribute("then_branch", subgraph.ToGraphProto());
if_node.AddAttribute("else_branch", subgraph.ToGraphProto());
+ if_node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
+
onnxruntime::Graph* subgraph_1 = if_node.GetMutableGraphAttribute("then_branch");
for (auto& node : subgraph_1->Nodes()) {
if (node.Name() == "node2") {
@@ -287,7 +303,7 @@ TEST(TransformerTest, TestInitializerDuplicationInSubgraph) {
KernelRegistryManager test_registry_manager;
ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
- MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+ MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
bool modified = false;
ASSERT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()));
@@ -329,7 +345,7 @@ TEST(TransformerTest, MemcpyTransformerTestGraphInputConsumedOnMultipleDevices)
KernelRegistryManager test_registry_manager;
ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
- MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+ MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
bool modified = false;
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
@@ -398,6 +414,8 @@ TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevice
if_node.AddAttribute("then_branch", subgraph.ToGraphProto());
if_node.AddAttribute("else_branch", subgraph.ToGraphProto());
+ if_node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
+
graph.SetInputs({&i1_def, &i2_def});
onnxruntime::Graph* subgraph_1 = if_node.GetMutableGraphAttribute("then_branch");
@@ -431,7 +449,7 @@ TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevice
KernelRegistryManager test_registry_manager;
ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
- MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
+ MemcpyTransformer transformer(GetNotNullProviderPtrs(execution_providers), test_registry_manager);
bool modified = false;
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index bef0bdd5295be..d56212510d2a9 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -795,6 +795,24 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
// Please make no more changes to the list
static const ORTCHAR_T* immutable_broken_tests[] =
{
+ // pending ONNX update
+ ORT_TSTR("attention_3d_gqa"),
+ ORT_TSTR("attention_3d_gqa_attn_mask"),
+ ORT_TSTR("attention_3d_gqa_causal"),
+ ORT_TSTR("attention_3d_gqa_scaled"),
+ ORT_TSTR("attention_3d_gqa_softcap"),
+ ORT_TSTR("attention_3d_gqa_with_past_and_present"),
+ ORT_TSTR("attention_4d_gqa"),
+ ORT_TSTR("attention_4d_gqa_attn_mask"),
+ ORT_TSTR("attention_4d_gqa_causal"),
+ ORT_TSTR("attention_4d_gqa_scaled"),
+ ORT_TSTR("attention_4d_gqa_softcap"),
+ ORT_TSTR("attention_4d_gqa_with_past_and_present"),
+ ORT_TSTR("attention_4d_diff_heads_mask4d_padded_kv"),
+ ORT_TSTR("attention_4d_gqa_with_past_and_present_fp16"),
+ ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal"),
+ ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal"),
+ // unsupported case
ORT_TSTR("AvgPool1d"),
ORT_TSTR("AvgPool1d_stride"),
ORT_TSTR("AvgPool2d"),
diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
index f6fce37322c10..c382612a6dff8 100644
--- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
@@ -4518,7 +4518,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) {
// changes during the layout transformation process.
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
- using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+ using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
// set the test EP to support all ops in the model so that the layout transform applies to all nodes
const std::unordered_set empty_set;
@@ -4544,7 +4544,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) {
"with the exception of the initial node prior to the Conv";
// all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout
- std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+ std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
for (const auto& node : graph.Nodes()) {
EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name()
<< "' was not assigned to the internal testing EP.";
@@ -4564,7 +4564,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) {
SessionOptions so;
- using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+ using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
// set the test EP to support all ops in the model so that the layout transform applies to all nodes
const std::unordered_set empty_set;
@@ -4590,7 +4590,7 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) {
"with the exception of the initial node prior to the Conv";
// all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout
- std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+ std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
for (const auto& node : graph.Nodes()) {
EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name()
<< "' was not assigned to the internal testing EP.";
@@ -4606,7 +4606,7 @@ TEST(TransposeOptimizerTests, QnnResizeOpset11) {
// Uncomment to debug
// ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
- using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+ using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
// set the test EP to support all ops in the model so that the layout transform applies to all nodes
const std::unordered_set empty_set;
@@ -4620,7 +4620,7 @@ TEST(TransposeOptimizerTests, QnnResizeOpset11) {
const auto& graph = session.GetGraph();
// all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout
- std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+ std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
for (const auto& node : graph.Nodes()) {
EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name()
<< "' was not assigned to the internal testing EP.";
@@ -4648,7 +4648,7 @@ TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) {
// ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
- using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+ using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
// set the test EP to support all ops in the model so that the layout transform applies to all nodes
const std::unordered_set empty_set;
@@ -4666,7 +4666,7 @@ TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) {
ASSERT_EQ(op_to_count["Transpose"], 3) << "Should have Transpose on 2 inputs and one on output.";
// all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout
- std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+ std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
for (const auto& node : graph.Nodes()) {
EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name()
<< "' was not assigned to the internal testing EP.";
@@ -4699,7 +4699,7 @@ TEST(TransposeOptimizerTests, LayoutTransformFixStuckTransposeWithoutDQ) {
// ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
- using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+ using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
// Set the test EP to support all ops in the model so that the layout transform applies to all nodes
const std::unordered_set empty_set;
@@ -4716,7 +4716,7 @@ TEST(TransposeOptimizerTests, LayoutTransformFixStuckTransposeWithoutDQ) {
ASSERT_EQ(op_to_count["Transpose"], 2) << "Should have 2 transposes remaining.";
- std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+ std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
for (const auto& node : graph.Nodes()) {
EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name()
<< "' was not assigned to the internal testing EP.";
@@ -4756,7 +4756,7 @@ TEST(TransposeOptimizerTests, LayoutTransformConstantFoldTransposeAndSqueeze) {
// ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
- using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+ using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
// Set the test EP to support all ops in the model so that the layout transform applies to all nodes
const std::unordered_set empty_set;
@@ -4777,7 +4777,7 @@ TEST(TransposeOptimizerTests, LayoutTransformConstantFoldTransposeAndSqueeze) {
// 1 transpose is constant-folded, 1 is canceled, and 1 remains.
ASSERT_EQ(op_to_count["Transpose"], 1) << "Should have 1 transpose remaining.";
- std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+ std::string expected_ep(internal_testing_ep::kInternalTestingExecutionProvider);
for (const auto& node : graph.Nodes()) {
EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name()
<< "' was not assigned to the internal testing EP.";
diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
index b4f6d328cacf7..54c2ed7d521db 100644
--- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
+++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
@@ -664,6 +664,7 @@ TEST(AttentionTest, Attention4DAttnPastPresent) {
false, true, true // disable_cpu, disable_cuda, disable_dml
);
}
+
TEST(AttentionTest, Attention4DAttnIsCausal) {
int batch_size = 2; // Q.shape[0]
int q_num_heads = 3; // Q.shape[1]
@@ -828,6 +829,38 @@ TEST(AttentionTest, Attention4DDiffHeadsWithPastAndPresent) {
);
}
+TEST(AttentionTest, Attention3DGqaAttn) {
+ int batch_size = 2; // Q.shape[0]
+ int q_num_heads = 9; // Q.shape[1]
+ int q_sequence_length = 4; // Q.shape[2]
+ int head_size = 8; // Q.shape[3]
+ int kv_sequence_length = 6; // K.shape[2] and V.shape[2]
+ int kv_num_heads = 3; // K.shape[1] and V.shape[1]
+ int v_head_size = 8; // V.shape[3]
+ int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2]
+
+ // {2, 4, 72}
+ std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f};
+ // {2, 6, 24}
+ std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f};
+ // {2, 6, 24}
+ std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f};
+ // {2, 4, 72}
+ std::vector y = {0.532009f, 0.526025f, 0.449746f, 0.551692f, 0.407822f, 0.436275f, 0.507807f, 0.457324f, 0.530536f, 0.517111f, 0.452785f, 0.557318f, 0.397721f, 0.434161f, 0.498276f, 0.464536f, 0.528016f, 0.548671f, 0.441040f, 0.542961f, 0.418557f, 0.444397f, 0.515088f, 0.452512f, 0.462161f, 0.530536f, 0.564630f, 0.418701f, 0.669452f, 0.633554f, 0.569379f, 0.430544f, 0.456026f, 0.529795f, 0.558238f, 0.411985f, 0.664240f, 0.619959f, 0.590516f, 0.438577f, 0.471552f, 0.521718f, 0.560465f, 0.404206f, 0.663920f, 0.628819f, 0.540935f, 0.447763f, 0.615083f, 0.344791f, 0.432664f, 0.451253f, 0.460813f, 0.441267f, 0.708582f, 0.530088f, 0.623659f, 0.343547f, 0.439418f, 0.450767f, 0.460055f, 0.442001f, 0.703292f, 0.522883f, 0.617738f, 0.343160f, 0.440540f, 0.440079f, 0.459815f, 0.436860f, 0.703290f, 0.534856f, 0.536138f, 0.499439f, 0.465771f, 0.565138f, 0.391402f, 0.430258f, 0.494915f, 0.463613f, 0.532752f, 0.526358f, 0.452075f, 0.562130f, 0.402551f, 0.442784f, 0.486721f, 0.456955f, 0.547578f, 0.527342f, 0.453800f, 0.548887f, 0.418444f, 0.438968f, 0.515475f, 0.444207f, 0.475352f, 0.524010f, 0.549702f, 0.420030f, 0.656346f, 0.620729f, 0.571884f, 0.431010f, 0.453307f, 0.522210f, 0.563368f, 0.412061f, 0.657897f, 0.634999f, 0.577458f, 0.451691f, 0.473936f, 0.524285f, 0.553525f, 0.421768f, 0.662288f, 0.622833f, 0.570081f, 0.432808f, 0.625738f, 0.353159f, 0.436185f, 0.448597f, 0.459371f, 0.429822f, 0.709026f, 0.526207f, 0.630878f, 0.351036f, 0.439799f, 0.452249f, 0.456486f, 0.431906f, 0.706014f, 0.518897f, 0.629526f, 0.351482f, 0.440728f, 0.449287f, 0.451705f, 0.426815f, 0.706598f, 0.522028f, 0.537899f, 0.527199f, 0.447980f, 0.548688f, 0.410653f, 0.436181f, 0.511135f, 0.455244f, 0.534560f, 0.540045f, 0.447505f, 0.552786f, 0.413302f, 0.446360f, 0.499945f, 0.450757f, 0.531708f, 0.526097f, 0.450511f, 0.553372f, 0.401450f, 0.438186f, 0.501418f, 0.462466f, 0.469643f, 0.527539f, 0.553613f, 0.418159f, 0.659814f, 0.622731f, 0.575224f, 0.429425f, 0.463941f, 0.524481f, 0.557632f, 0.413729f, 0.657415f, 0.629157f, 0.570920f, 0.439773f, 0.479643f, 0.526773f, 0.556809f, 0.422406f, 0.670038f, 0.625300f, 0.554451f, 0.426587f, 0.630894f, 0.353011f, 0.444285f, 0.443177f, 0.448608f, 0.419312f, 0.705883f, 0.526260f, 0.631310f, 0.347563f, 0.445672f, 0.446224f, 0.448210f, 0.428481f, 0.702004f, 0.519990f, 0.626158f, 0.342802f, 0.449770f, 0.440666f, 0.453705f, 0.427492f, 0.700510f, 0.533279f, 0.526144f, 0.538202f, 0.443619f, 0.551579f, 0.407162f, 0.442426f, 0.499995f, 0.459987f, 0.525627f, 0.544718f, 0.448060f, 0.544942f, 0.415781f, 0.444198f, 0.516948f, 0.452985f, 0.521784f, 0.523083f, 0.450924f, 0.565538f, 0.392054f, 0.440702f, 0.479094f, 0.468113f, 0.473886f, 0.523677f, 0.555144f, 0.409412f, 0.664285f, 0.620163f, 0.555448f, 0.440947f, 0.459210f, 0.528829f, 0.567231f, 0.413602f, 0.672778f, 0.632467f, 0.565881f, 0.439895f, 0.480238f, 0.525127f, 0.554365f, 0.431656f, 0.658900f, 0.634358f, 0.561181f, 0.419623f, 0.646099f, 0.364754f, 0.442180f, 0.450340f, 0.441320f, 0.412523f, 0.708121f, 0.505939f, 0.641772f, 0.375478f, 0.428502f, 0.454772f, 0.439016f, 0.407773f, 0.718457f, 0.504047f, 0.628271f, 0.345239f, 0.449391f, 0.436208f, 0.448766f, 0.426444f, 0.699202f, 0.528374f, 0.489165f, 0.818278f, 0.467403f, 0.370507f, 0.572406f, 0.417942f, 0.160316f, 0.384139f, 0.497723f, 0.820329f, 0.455669f, 0.373132f, 0.568626f, 0.418602f, 0.164551f, 0.404233f, 0.488972f, 0.813399f, 0.460936f, 0.369774f, 0.580477f, 0.417018f, 0.167442f, 0.381535f, 0.603715f, 0.360599f, 0.371685f, 0.614777f, 0.440767f, 0.425124f, 0.369342f, 0.828101f, 0.584460f, 0.352249f, 0.382191f, 0.613073f, 0.431223f, 0.421802f, 0.389292f, 0.831202f, 0.590574f, 0.355658f, 0.373391f, 0.623741f, 0.432416f, 0.412097f, 0.378312f, 0.829226f, 0.365226f, 0.726961f, 0.549872f, 0.239494f, 0.496434f, 0.668542f, 0.557774f, 0.487281f, 0.361340f, 0.749156f, 0.523408f, 0.240555f, 0.493770f, 0.639516f, 0.552116f, 0.478230f, 0.367118f, 0.740114f, 0.563789f, 0.238852f, 0.498407f, 0.682064f, 0.571327f, 0.496416f, 0.480636f, 0.820258f, 0.464776f, 0.362168f, 0.567256f, 0.417842f, 0.161815f, 0.387104f, 0.486998f, 0.821507f, 0.467362f, 0.377934f, 0.569593f, 0.418367f, 0.156778f, 0.390179f, 0.461449f, 0.823726f, 0.471401f, 0.361646f, 0.563554f, 0.418609f, 0.154999f, 0.379696f, 0.565916f, 0.345293f, 0.392969f, 0.612305f, 0.418858f, 0.416238f, 0.410985f, 0.833515f, 0.552881f, 0.338985f, 0.394863f, 0.597100f, 0.422296f, 0.401025f, 0.427810f, 0.831702f, 0.558983f, 0.339943f, 0.393544f, 0.583418f, 0.432193f, 0.405729f, 0.426401f, 0.830305f, 0.362801f, 0.731181f, 0.546338f, 0.247016f, 0.499389f, 0.662441f, 0.544727f, 0.486631f, 0.355514f, 0.726998f, 0.518056f, 0.249475f, 0.492155f, 0.643678f, 0.531052f, 0.481617f, 0.370308f, 0.743741f, 0.562172f, 0.233361f, 0.498431f, 0.679567f, 0.580747f, 0.494199f, 0.481097f, 0.817782f, 0.461707f, 0.369188f, 0.573825f, 0.419752f, 0.161614f, 0.386708f, 0.472911f, 0.822003f, 0.473412f, 0.375830f, 0.569966f, 0.422158f, 0.149228f, 0.380008f, 0.454662f, 0.818956f, 0.465984f, 0.370169f, 0.575537f, 0.423344f, 0.153818f, 0.375466f, 0.572526f, 0.348075f, 0.380718f, 0.641409f, 0.417012f, 0.407621f, 0.389074f, 0.834251f, 0.581008f, 0.348183f, 0.383659f, 0.608061f, 0.435032f, 0.422240f, 0.393710f, 0.832528f, 0.600530f, 0.360439f, 0.371006f, 0.609018f, 0.441082f, 0.416286f, 0.374920f, 0.825853f, 0.364932f, 0.727047f, 0.540001f, 0.246375f, 0.501524f, 0.656266f, 0.541761f, 0.482865f, 0.360322f, 0.752650f, 0.542120f, 0.239561f, 0.491207f, 0.663446f, 0.566643f, 0.491988f, 0.364532f, 0.737402f, 0.546869f, 0.240953f, 0.497072f, 0.664793f, 0.558528f, 0.488182f, 0.490592f, 0.819727f, 0.468739f, 0.379671f, 0.572959f, 0.422399f, 0.152699f, 0.387445f, 0.462308f, 0.822644f, 0.463886f, 0.374320f, 0.569615f, 0.423238f, 0.152603f, 0.387850f, 0.451896f, 0.818576f, 0.449904f, 0.362889f, 0.573917f, 0.421849f, 0.165145f, 0.390440f, 0.565044f, 0.343397f, 0.395512f, 0.584043f, 0.431062f, 0.417783f, 0.421165f, 0.830938f, 0.583998f, 0.354061f, 0.374016f, 0.633981f, 0.424457f, 0.404069f, 0.381920f, 0.829920f, 0.568315f, 0.347357f, 0.386911f, 0.624227f, 0.418162f, 0.411256f, 0.400332f, 0.832994f, 0.370475f, 0.739716f, 0.551429f, 0.234114f, 0.499500f, 0.665245f, 0.570648f, 0.485298f, 0.364035f, 0.756092f, 0.542251f, 0.238706f, 0.495463f, 0.659518f, 0.567976f, 0.489204f, 0.368942f, 0.756397f, 0.548083f, 0.231854f, 0.496617f, 0.659726f, 0.578330f, 0.484921f};
+
+ ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size);
+ ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size);
+ ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size);
+ ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size);
+
+ RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
+ q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
+ -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
+ y, std::vector(), std::vector(), std::vector(),
+ false, true, true // disable_cpu, disable_cuda, disable_dml
+ );
+}
+
TEST(AttentionTest, Attention4DGqaAttnMask) {
int batch_size = 2; // Q.shape[0]
int q_num_heads = 9; // Q.shape[1]
@@ -847,7 +880,7 @@ TEST(AttentionTest, Attention4DGqaAttnMask) {
// {4, 6}
std::vector m = {0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f};
// {2, 9, 4, 8}
- std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.508992f, 0.253478f, 0.553979f, 0.466355f, 0.398637f, 0.412493f, 0.495810f, 0.677675f, 0.521609f, 0.278997f, 0.564189f, 0.434417f, 0.448085f, 0.467205f, 0.567856f, 0.664713f, 0.490146f, 0.261321f, 0.560582f, 0.424598f, 0.450318f, 0.467336f, 0.520983f, 0.720798f, 0.516095f, 0.264495f, 0.577940f, 0.475340f, 0.444145f, 0.477909f, 0.485663f, 0.672846f, 0.499389f, 0.402198f, 0.520218f, 0.550550f, 0.481065f, 0.730488f, 0.492535f, 0.392315f, 0.436722f, 0.398514f, 0.497457f, 0.502270f, 0.520993f, 0.730472f, 0.565429f, 0.380282f, 0.461226f, 0.392968f, 0.536035f, 0.505191f, 0.446570f, 0.751253f, 0.478584f, 0.389036f, 0.423738f, 0.443828f, 0.554323f, 0.462607f, 0.476656f, 0.733228f, 0.482219f, 0.411910f, 0.620556f, 0.662948f, 0.349409f, 0.482541f, 0.537250f, 0.351544f, 0.734285f, 0.397172f, 0.689500f, 0.637077f, 0.320710f, 0.470914f, 0.526307f, 0.312878f, 0.775762f, 0.384457f, 0.696615f, 0.681034f, 0.324383f, 0.459632f, 0.539497f, 0.317950f, 0.709736f, 0.320698f, 0.671696f, 0.676830f, 0.332387f, 0.453234f, 0.578648f, 0.345084f, 0.685369f, 0.328092f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.500042f, 0.410507f, 0.521381f, 0.553244f, 0.459062f, 0.719706f, 0.476571f, 0.395052f, 0.429926f, 0.408857f, 0.507006f, 0.493937f, 0.529878f, 0.728873f, 0.571495f, 0.376256f, 0.453676f, 0.380482f, 0.526100f, 0.496696f, 0.457383f, 0.761933f, 0.486657f, 0.396608f, 0.435748f, 0.432822f, 0.531763f, 0.482255f, 0.477046f, 0.726381f, 0.487480f, 0.416572f, 0.626676f, 0.683736f, 0.340657f, 0.475002f, 0.549981f, 0.353311f, 0.740157f, 0.378827f, 0.681403f, 0.636622f, 0.324593f, 0.469088f, 0.537323f, 0.321344f, 0.762506f, 0.384239f, 0.693108f, 0.683351f, 0.329873f, 0.460504f, 0.555115f, 0.325379f, 0.694659f, 0.316422f, 0.677285f, 0.670298f, 0.329724f, 0.456327f, 0.567533f, 0.337560f, 0.701396f, 0.336191f, 0.515940f, 0.251020f, 0.562035f, 0.442479f, 0.405802f, 0.410828f, 0.519841f, 0.686781f, 0.522057f, 0.285013f, 0.562761f, 0.453472f, 0.451971f, 0.481286f, 0.558322f, 0.649971f, 0.486787f, 0.258011f, 0.557963f, 0.426743f, 0.442028f, 0.457034f, 0.510534f, 0.724945f, 0.498901f, 0.272090f, 0.572650f, 0.467930f, 0.465335f, 0.506181f, 0.484559f, 0.690090f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.379088f, 0.582413f, 0.414383f, 0.571800f, 0.613176f, 0.687631f, 0.185596f, 0.656867f, 0.390452f, 0.532452f, 0.407547f, 0.564799f, 0.606499f, 0.653258f, 0.176547f, 0.698038f, 0.410398f, 0.604586f, 0.442972f, 0.497533f, 0.595085f, 0.732265f, 0.187201f, 0.663169f, 0.448716f, 0.590302f, 0.411879f, 0.518449f, 0.636722f, 0.695827f, 0.154292f, 0.666828f, 0.458054f, 0.608582f, 0.430376f, 0.316371f, 0.547620f, 0.542559f, 0.542043f, 0.556297f, 0.468371f, 0.559154f, 0.465195f, 0.344099f, 0.482571f, 0.527115f, 0.527529f, 0.616254f, 0.494566f, 0.605555f, 0.432360f, 0.382197f, 0.466678f, 0.556031f, 0.459313f, 0.588575f, 0.532798f, 0.597684f, 0.412305f, 0.393400f, 0.462773f, 0.491821f, 0.483189f, 0.593919f, 0.569241f, 0.793791f, 0.532988f, 0.300026f, 0.393843f, 0.327085f, 0.448199f, 0.457416f, 0.493302f, 0.725336f, 0.512066f, 0.327500f, 0.404238f, 0.351704f, 0.507818f, 0.477990f, 0.479548f, 0.756083f, 0.511730f, 0.309729f, 0.366024f, 0.338031f, 0.503335f, 0.472352f, 0.473026f, 0.696816f, 0.543129f, 0.374608f, 0.335432f, 0.360978f, 0.486364f, 0.531799f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.453034f, 0.596627f, 0.417365f, 0.314318f, 0.554269f, 0.518967f, 0.550250f, 0.556252f, 0.494918f, 0.587774f, 0.467566f, 0.350222f, 0.481994f, 0.538857f, 0.525631f, 0.605359f, 0.497486f, 0.608472f, 0.429145f, 0.384532f, 0.466790f, 0.554752f, 0.457698f, 0.586510f, 0.548577f, 0.604359f, 0.398097f, 0.414429f, 0.448200f, 0.485158f, 0.461395f, 0.593015f, 0.563470f, 0.796184f, 0.532783f, 0.293209f, 0.408910f, 0.327450f, 0.438028f, 0.447011f, 0.493041f, 0.739603f, 0.496957f, 0.311881f, 0.389768f, 0.352503f, 0.530113f, 0.476738f, 0.484897f, 0.752985f, 0.511921f, 0.312174f, 0.370408f, 0.339775f, 0.504061f, 0.473793f, 0.487978f, 0.714687f, 0.538817f, 0.358426f, 0.348908f, 0.355820f, 0.481380f, 0.516214f, 0.370872f, 0.602034f, 0.400225f, 0.611090f, 0.630508f, 0.662527f, 0.162489f, 0.658299f, 0.378734f, 0.537283f, 0.412214f, 0.570032f, 0.601452f, 0.653569f, 0.179932f, 0.693105f, 0.411981f, 0.605715f, 0.448022f, 0.481469f, 0.585099f, 0.748463f, 0.195177f, 0.671915f, 0.442141f, 0.581881f, 0.393362f, 0.555388f, 0.650764f, 0.665937f, 0.141141f, 0.675100f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f};
+ std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.644472f, 0.666279f, 0.336558f, 0.478260f, 0.534820f, 0.338286f, 0.756443f, 0.387184f, 0.674255f, 0.645509f, 0.327427f, 0.465534f, 0.543598f, 0.328256f, 0.743604f, 0.373978f, 0.689753f, 0.687485f, 0.332246f, 0.457085f, 0.565540f, 0.331625f, 0.677863f, 0.308191f, 0.663033f, 0.669169f, 0.333832f, 0.452516f, 0.576569f, 0.348823f, 0.685447f, 0.338196f, 0.613061f, 0.681689f, 0.345384f, 0.474784f, 0.541609f, 0.357958f, 0.728217f, 0.383408f, 0.680108f, 0.637886f, 0.329455f, 0.469504f, 0.544973f, 0.325193f, 0.745572f, 0.378169f, 0.695405f, 0.687321f, 0.323229f, 0.456101f, 0.553544f, 0.323743f, 0.706057f, 0.314785f, 0.672814f, 0.678842f, 0.323628f, 0.449345f, 0.572724f, 0.342071f, 0.707722f, 0.332714f, 0.512254f, 0.252087f, 0.555774f, 0.456582f, 0.393340f, 0.400567f, 0.501655f, 0.680466f, 0.530775f, 0.288611f, 0.570275f, 0.444357f, 0.454871f, 0.480588f, 0.567893f, 0.645871f, 0.491847f, 0.262209f, 0.561930f, 0.418081f, 0.444398f, 0.456345f, 0.519658f, 0.722565f, 0.523232f, 0.267034f, 0.591659f, 0.459565f, 0.462164f, 0.494775f, 0.497558f, 0.678628f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.513034f, 0.252153f, 0.561841f, 0.455825f, 0.411518f, 0.424734f, 0.508095f, 0.683202f, 0.537475f, 0.278680f, 0.572605f, 0.449901f, 0.433722f, 0.452424f, 0.554372f, 0.643199f, 0.503808f, 0.259719f, 0.571011f, 0.415224f, 0.442363f, 0.450636f, 0.525191f, 0.716156f, 0.524579f, 0.263175f, 0.588806f, 0.462952f, 0.450874f, 0.480435f, 0.495070f, 0.675950f, 0.503113f, 0.409947f, 0.538941f, 0.550010f, 0.457564f, 0.729741f, 0.472483f, 0.384586f, 0.421666f, 0.416784f, 0.522405f, 0.484472f, 0.519795f, 0.728113f, 0.570887f, 0.363251f, 0.462182f, 0.372738f, 0.510951f, 0.511798f, 0.446353f, 0.754695f, 0.485592f, 0.397135f, 0.421437f, 0.447040f, 0.546262f, 0.462919f, 0.473860f, 0.726421f, 0.479062f, 0.420641f, 0.498228f, 0.402912f, 0.524895f, 0.548811f, 0.462668f, 0.729601f, 0.480759f, 0.390396f, 0.421638f, 0.418506f, 0.518644f, 0.484993f, 0.512452f, 0.724489f, 0.562537f, 0.370564f, 0.461864f, 0.376424f, 0.511195f, 0.510163f, 0.461531f, 0.755198f, 0.491549f, 0.400847f, 0.425338f, 0.456035f, 0.553542f, 0.466468f, 0.482400f, 0.722062f, 0.483532f, 0.415135f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.560779f, 0.795626f, 0.527843f, 0.292198f, 0.403399f, 0.328103f, 0.449548f, 0.449270f, 0.492632f, 0.741337f, 0.501964f, 0.308729f, 0.404425f, 0.353946f, 0.510715f, 0.469292f, 0.498506f, 0.749246f, 0.510938f, 0.317603f, 0.377607f, 0.333171f, 0.516589f, 0.472113f, 0.494030f, 0.738331f, 0.525273f, 0.334388f, 0.351797f, 0.349013f, 0.492978f, 0.499192f, 0.558701f, 0.785575f, 0.541472f, 0.309741f, 0.379566f, 0.336180f, 0.433460f, 0.471779f, 0.500494f, 0.748997f, 0.495158f, 0.302537f, 0.401868f, 0.348977f, 0.525071f, 0.465493f, 0.496427f, 0.763380f, 0.504640f, 0.303037f, 0.375539f, 0.332025f, 0.517142f, 0.464096f, 0.466789f, 0.731320f, 0.529262f, 0.338950f, 0.329005f, 0.361720f, 0.481664f, 0.514476f, 0.356477f, 0.623874f, 0.420893f, 0.592125f, 0.610336f, 0.687956f, 0.174269f, 0.652548f, 0.366057f, 0.567382f, 0.428770f, 0.553226f, 0.582617f, 0.683498f, 0.188604f, 0.695704f, 0.406930f, 0.625170f, 0.441775f, 0.499327f, 0.590722f, 0.740689f, 0.180721f, 0.681143f, 0.430954f, 0.584531f, 0.412720f, 0.532459f, 0.630830f, 0.690216f, 0.161882f, 0.663851f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.365268f, 0.611770f, 0.413907f, 0.600775f, 0.622849f, 0.667798f, 0.164152f, 0.647839f, 0.377540f, 0.543255f, 0.401769f, 0.588162f, 0.610896f, 0.645976f, 0.172500f, 0.695675f, 0.428349f, 0.590245f, 0.429343f, 0.497694f, 0.606978f, 0.727059f, 0.182826f, 0.671502f, 0.466759f, 0.580932f, 0.396764f, 0.527984f, 0.655065f, 0.677027f, 0.138356f, 0.672848f, 0.431113f, 0.593599f, 0.391529f, 0.327778f, 0.551802f, 0.526872f, 0.512055f, 0.547473f, 0.461591f, 0.564565f, 0.469932f, 0.335454f, 0.493299f, 0.536959f, 0.537769f, 0.611109f, 0.505296f, 0.606927f, 0.414343f, 0.395585f, 0.462205f, 0.538029f, 0.450814f, 0.585742f, 0.550355f, 0.606479f, 0.419783f, 0.396625f, 0.449703f, 0.500831f, 0.464506f, 0.594653f, 0.460993f, 0.609826f, 0.424563f, 0.322395f, 0.546231f, 0.537700f, 0.541169f, 0.555672f, 0.479953f, 0.573210f, 0.449011f, 0.356276f, 0.482535f, 0.523785f, 0.516393f, 0.605958f, 0.473948f, 0.587667f, 0.412118f, 0.378344f, 0.472903f, 0.540161f, 0.445341f, 0.585184f, 0.561693f, 0.609513f, 0.394200f, 0.418769f, 0.444939f, 0.478136f, 0.458334f, 0.591187f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f};
ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size);
ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size);
@@ -886,7 +919,7 @@ TEST(AttentionTest, Attention4DGqaWithPastAndPresent) {
// {2, 3, 12, 8}
std::vector past_value = {0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.481102f, 0.251523f, 0.876682f, 0.324273f, 0.924623f, 0.974787f, 0.449862f, 0.227129f, 0.291666f, 0.776334f, 0.273350f, 0.380583f, 0.478576f, 0.575111f, 0.996100f, 0.232210f, 0.353424f, 0.262891f, 0.361113f, 0.100805f, 0.359810f, 0.887865f, 0.298590f, 0.371935f};
// {2, 9, 4, 8}
- std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.440810f, 0.437705f, 0.476508f, 0.320820f, 0.605191f, 0.640150f, 0.306216f, 0.610947f, 0.485794f, 0.448216f, 0.485639f, 0.323744f, 0.594446f, 0.646597f, 0.321742f, 0.605751f, 0.501858f, 0.445502f, 0.487899f, 0.384660f, 0.597134f, 0.616430f, 0.331401f, 0.566459f, 0.502522f, 0.409965f, 0.526639f, 0.348601f, 0.565200f, 0.586558f, 0.325044f, 0.603422f, 0.450250f, 0.368009f, 0.550911f, 0.460338f, 0.523907f, 0.508816f, 0.575624f, 0.426601f, 0.472310f, 0.372844f, 0.517852f, 0.431688f, 0.551555f, 0.527657f, 0.600578f, 0.473069f, 0.456633f, 0.442035f, 0.539875f, 0.437863f, 0.540202f, 0.499608f, 0.556470f, 0.419831f, 0.463081f, 0.416724f, 0.526389f, 0.458654f, 0.540120f, 0.551554f, 0.569399f, 0.447102f, 0.534296f, 0.597655f, 0.509699f, 0.487167f, 0.607438f, 0.426383f, 0.522794f, 0.458435f, 0.510147f, 0.622761f, 0.501724f, 0.453386f, 0.629671f, 0.434103f, 0.582477f, 0.437681f, 0.520031f, 0.568543f, 0.525216f, 0.490370f, 0.571745f, 0.428629f, 0.572995f, 0.460086f, 0.533607f, 0.614962f, 0.474130f, 0.456345f, 0.576467f, 0.448127f, 0.599211f, 0.432252f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.452598f, 0.361594f, 0.550919f, 0.455099f, 0.530404f, 0.519313f, 0.588655f, 0.431890f, 0.464325f, 0.389636f, 0.515359f, 0.429087f, 0.540767f, 0.518376f, 0.586627f, 0.471074f, 0.458527f, 0.422216f, 0.537762f, 0.434123f, 0.550956f, 0.507704f, 0.564828f, 0.421548f, 0.463044f, 0.407985f, 0.523093f, 0.473684f, 0.542663f, 0.551348f, 0.576783f, 0.448743f, 0.546208f, 0.621128f, 0.501647f, 0.468191f, 0.612298f, 0.425183f, 0.549241f, 0.447622f, 0.519355f, 0.619636f, 0.487775f, 0.444259f, 0.625749f, 0.430264f, 0.584338f, 0.436887f, 0.521021f, 0.572716f, 0.522539f, 0.486440f, 0.581317f, 0.429079f, 0.579691f, 0.455426f, 0.526431f, 0.604615f, 0.476481f, 0.469814f, 0.588766f, 0.445640f, 0.609160f, 0.437785f, 0.443498f, 0.439338f, 0.487424f, 0.310942f, 0.607341f, 0.630362f, 0.312591f, 0.621999f, 0.483917f, 0.446308f, 0.477454f, 0.331028f, 0.592608f, 0.653297f, 0.322368f, 0.599377f, 0.497354f, 0.443447f, 0.477781f, 0.384002f, 0.591587f, 0.610287f, 0.328537f, 0.567630f, 0.499369f, 0.421961f, 0.536492f, 0.345379f, 0.586450f, 0.600541f, 0.312965f, 0.609437f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.557802f, 0.585925f, 0.426858f, 0.464044f, 0.585251f, 0.557395f, 0.433327f, 0.615342f, 0.534368f, 0.573723f, 0.426393f, 0.518102f, 0.586735f, 0.513129f, 0.371969f, 0.636735f, 0.544166f, 0.588469f, 0.433470f, 0.481894f, 0.595019f, 0.533156f, 0.396519f, 0.608115f, 0.547125f, 0.604473f, 0.441984f, 0.469765f, 0.599107f, 0.561685f, 0.347618f, 0.563457f, 0.507550f, 0.485293f, 0.545846f, 0.408434f, 0.482538f, 0.532314f, 0.498883f, 0.525126f, 0.514603f, 0.471457f, 0.539705f, 0.362410f, 0.490158f, 0.513690f, 0.494170f, 0.496909f, 0.492936f, 0.506153f, 0.565865f, 0.364727f, 0.508899f, 0.516217f, 0.558362f, 0.556920f, 0.530472f, 0.521715f, 0.554673f, 0.363830f, 0.509086f, 0.511590f, 0.552396f, 0.541486f, 0.572145f, 0.551531f, 0.471964f, 0.485188f, 0.555030f, 0.493247f, 0.376875f, 0.429387f, 0.580540f, 0.550944f, 0.435664f, 0.480675f, 0.544997f, 0.488698f, 0.344985f, 0.464878f, 0.593774f, 0.541202f, 0.484834f, 0.497316f, 0.509364f, 0.500045f, 0.357235f, 0.448933f, 0.565242f, 0.546653f, 0.459790f, 0.481954f, 0.514950f, 0.516297f, 0.344285f, 0.454476f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.507041f, 0.473640f, 0.547768f, 0.413960f, 0.490513f, 0.534377f, 0.497277f, 0.517772f, 0.531394f, 0.489105f, 0.531671f, 0.369343f, 0.486462f, 0.501787f, 0.494220f, 0.493498f, 0.485968f, 0.510301f, 0.559766f, 0.361474f, 0.507888f, 0.518858f, 0.564300f, 0.561990f, 0.537984f, 0.527982f, 0.539571f, 0.366920f, 0.498313f, 0.505709f, 0.538027f, 0.541246f, 0.585733f, 0.565800f, 0.441346f, 0.476255f, 0.556453f, 0.497693f, 0.363246f, 0.426799f, 0.578484f, 0.556489f, 0.436699f, 0.481177f, 0.549473f, 0.484153f, 0.355910f, 0.462010f, 0.590951f, 0.542803f, 0.470954f, 0.488994f, 0.512707f, 0.511876f, 0.358555f, 0.455953f, 0.559449f, 0.546003f, 0.462900f, 0.471080f, 0.517298f, 0.519225f, 0.345016f, 0.449149f, 0.526624f, 0.606761f, 0.427660f, 0.480775f, 0.577420f, 0.538850f, 0.426959f, 0.625509f, 0.530502f, 0.585784f, 0.432234f, 0.516800f, 0.584937f, 0.514154f, 0.373726f, 0.623740f, 0.550470f, 0.585577f, 0.436483f, 0.474799f, 0.594100f, 0.540052f, 0.402520f, 0.607686f, 0.537556f, 0.609680f, 0.439490f, 0.477886f, 0.602656f, 0.542957f, 0.350394f, 0.574553f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f};
+ std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.546090f, 0.618047f, 0.504325f, 0.472246f, 0.609686f, 0.422467f, 0.546964f, 0.451166f, 0.519404f, 0.617868f, 0.491984f, 0.445771f, 0.633094f, 0.436822f, 0.559753f, 0.447209f, 0.519860f, 0.574899f, 0.525759f, 0.489339f, 0.586803f, 0.436452f, 0.577737f, 0.453299f, 0.532473f, 0.609446f, 0.471758f, 0.455772f, 0.573504f, 0.445466f, 0.602573f, 0.433307f, 0.538062f, 0.604199f, 0.500302f, 0.479569f, 0.614174f, 0.429231f, 0.522434f, 0.459369f, 0.528422f, 0.620683f, 0.485333f, 0.435606f, 0.616579f, 0.432233f, 0.565856f, 0.440093f, 0.525356f, 0.580613f, 0.529584f, 0.483095f, 0.583395f, 0.433491f, 0.593043f, 0.451879f, 0.540119f, 0.622995f, 0.472122f, 0.449888f, 0.586202f, 0.447435f, 0.611846f, 0.434879f, 0.449905f, 0.430732f, 0.474834f, 0.321674f, 0.590495f, 0.626300f, 0.319127f, 0.606006f, 0.492763f, 0.445330f, 0.490219f, 0.319940f, 0.588298f, 0.643644f, 0.317760f, 0.596360f, 0.507993f, 0.440004f, 0.490555f, 0.378128f, 0.588227f, 0.604974f, 0.329202f, 0.561987f, 0.511572f, 0.403440f, 0.542761f, 0.331792f, 0.568397f, 0.583366f, 0.333122f, 0.608456f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.441356f, 0.431701f, 0.488343f, 0.311828f, 0.606159f, 0.632821f, 0.317863f, 0.629084f, 0.495613f, 0.441177f, 0.473223f, 0.335484f, 0.579139f, 0.646878f, 0.321269f, 0.595437f, 0.504999f, 0.443626f, 0.498154f, 0.369326f, 0.588410f, 0.600189f, 0.322347f, 0.562676f, 0.508419f, 0.405342f, 0.533092f, 0.335876f, 0.570568f, 0.589600f, 0.330741f, 0.609168f, 0.456943f, 0.365603f, 0.555030f, 0.454344f, 0.526263f, 0.519062f, 0.578652f, 0.425453f, 0.464039f, 0.391848f, 0.518985f, 0.419419f, 0.541410f, 0.514459f, 0.586459f, 0.470210f, 0.460338f, 0.408599f, 0.539512f, 0.446249f, 0.551945f, 0.511356f, 0.575513f, 0.424325f, 0.452212f, 0.418205f, 0.525148f, 0.459799f, 0.536327f, 0.541881f, 0.571451f, 0.452969f, 0.454154f, 0.354641f, 0.553889f, 0.451027f, 0.536270f, 0.521832f, 0.590756f, 0.429859f, 0.459101f, 0.394962f, 0.512076f, 0.419296f, 0.535702f, 0.516757f, 0.585606f, 0.478117f, 0.458365f, 0.422929f, 0.531943f, 0.447581f, 0.546387f, 0.511705f, 0.564350f, 0.425332f, 0.463274f, 0.429223f, 0.525922f, 0.452328f, 0.539095f, 0.534372f, 0.563738f, 0.449120f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.586343f, 0.566462f, 0.444339f, 0.481474f, 0.557556f, 0.495837f, 0.368487f, 0.425850f, 0.580159f, 0.565990f, 0.400882f, 0.462578f, 0.551037f, 0.497924f, 0.338502f, 0.468483f, 0.592753f, 0.536897f, 0.481975f, 0.489485f, 0.519290f, 0.509298f, 0.366838f, 0.461538f, 0.567139f, 0.559419f, 0.458050f, 0.468739f, 0.514875f, 0.512271f, 0.346335f, 0.449357f, 0.583058f, 0.557532f, 0.454426f, 0.492673f, 0.551748f, 0.496414f, 0.364023f, 0.430048f, 0.579431f, 0.565100f, 0.420761f, 0.466297f, 0.551315f, 0.487418f, 0.348148f, 0.461136f, 0.585687f, 0.535194f, 0.485465f, 0.488622f, 0.513327f, 0.508844f, 0.368049f, 0.455823f, 0.554855f, 0.560589f, 0.456398f, 0.477641f, 0.507017f, 0.518069f, 0.338229f, 0.444624f, 0.500594f, 0.616610f, 0.439949f, 0.495561f, 0.569213f, 0.540425f, 0.422667f, 0.627919f, 0.514283f, 0.584446f, 0.441141f, 0.528331f, 0.577047f, 0.508969f, 0.372295f, 0.646734f, 0.536256f, 0.591823f, 0.428652f, 0.485852f, 0.592863f, 0.525360f, 0.399985f, 0.623408f, 0.552463f, 0.606841f, 0.448560f, 0.466321f, 0.600628f, 0.566464f, 0.356481f, 0.551351f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.532692f, 0.601573f, 0.425963f, 0.477495f, 0.573122f, 0.544325f, 0.422438f, 0.629794f, 0.512145f, 0.593241f, 0.436187f, 0.532146f, 0.582008f, 0.499410f, 0.366728f, 0.631277f, 0.550263f, 0.590346f, 0.430967f, 0.477189f, 0.600022f, 0.528313f, 0.406504f, 0.603355f, 0.537075f, 0.605495f, 0.437735f, 0.474413f, 0.601068f, 0.542204f, 0.348555f, 0.581430f, 0.499619f, 0.480920f, 0.536032f, 0.413380f, 0.478027f, 0.524393f, 0.490201f, 0.530954f, 0.517442f, 0.475326f, 0.541763f, 0.366450f, 0.498398f, 0.509411f, 0.503732f, 0.490468f, 0.488084f, 0.505941f, 0.554614f, 0.371690f, 0.503635f, 0.510325f, 0.557424f, 0.564303f, 0.534730f, 0.536543f, 0.563296f, 0.362277f, 0.498957f, 0.508357f, 0.538003f, 0.554638f, 0.514150f, 0.481676f, 0.543535f, 0.414778f, 0.478296f, 0.529467f, 0.496600f, 0.522262f, 0.522734f, 0.480361f, 0.534209f, 0.379264f, 0.485836f, 0.500082f, 0.498644f, 0.501901f, 0.474729f, 0.503193f, 0.560206f, 0.362595f, 0.515144f, 0.512647f, 0.557224f, 0.567242f, 0.539217f, 0.533273f, 0.538641f, 0.373064f, 0.495733f, 0.499786f, 0.532998f, 0.547731f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f};
// {2, 3, 18, 8}
std::vector present_key = {0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f};
// {2, 3, 18, 8}
@@ -1116,7 +1149,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) {
std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f};
// {2, 3, 6, 4}
std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f};
- // {2, 1, 4, 13}
+ // {2, 1, 4, 18}
std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, -0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, -0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f};
// {2, 3, 12, 4}
std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f};
@@ -1132,7 +1165,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) {
// {2, 3, 4, 4}
std::vector y = {-0.393782f, -0.387694f, -0.381606f, -0.375519f, -0.397492f, -0.391304f, -0.385116f, -0.378928f, -0.397474f, -0.391207f, -0.384941f, -0.378674f, -0.394849f, -0.388519f, -0.382190f, -0.375860f, -0.226271f, -0.220186f, -0.214101f, -0.208016f, -0.230042f, -0.223857f, -0.217672f, -0.211488f, -0.230104f, -0.223841f, -0.217577f, -0.211314f, -0.227525f, -0.221197f, -0.214870f, -0.208543f, -0.058757f, -0.052674f, -0.046592f, -0.040510f, -0.062587f, -0.056406f, -0.050224f, -0.044042f, -0.062730f, -0.056470f, -0.050209f, -0.043949f, -0.060198f, -0.053873f, -0.047548f, -0.041223f, 0.108760f, 0.114840f, 0.120919f, 0.126999f, 0.104873f, 0.111051f, 0.117229f, 0.123408f, 0.104648f, 0.110906f, 0.117163f, 0.123421f, 0.107131f, 0.113454f, 0.119777f, 0.126099f, 0.276279f, 0.282356f, 0.288433f, 0.294510f, 0.272337f, 0.278512f, 0.284687f, 0.290862f, 0.272031f, 0.278286f, 0.284540f, 0.290794f, 0.274463f, 0.280783f, 0.287104f, 0.293424f, 0.443800f, 0.449874f, 0.455949f, 0.462023f, 0.439807f, 0.445978f, 0.452150f, 0.458321f, 0.439418f, 0.445669f, 0.451921f, 0.458172f, 0.441797f, 0.448115f, 0.454433f, 0.460751f};
- // {2, 3, 13, 4}
+ // {2, 3, 12, 4}
std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f};
// {2, 3, 18, 8}
std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f};
@@ -1151,28 +1184,28 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) {
);
}
-TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) {
- int batch_size = 2; // Q.shape[0]
- int q_num_heads = 3; // Q.shape[1]
- int q_sequence_length = 4; // Q.shape[2]
- int head_size = 4; // Q.shape[3]
- int kv_sequence_length = 6; // K.shape[2] and V.shape[2]
- int kv_num_heads = 3; // K.shape[1] and V.shape[1]
- int v_head_size = 4; // V.shape[3]
- int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2]
+TEST(AttentionTest, TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal) {
+ int batch_size = 2; // Q.shape[0]
+ int q_num_heads = 3; // Q.shape[1]
+ int q_sequence_length = 4; // Q.shape[2]
+ int head_size = 8; // Q.shape[3]
+ int kv_sequence_length = 6; // K.shape[2] and V.shape[2]
+ int kv_num_heads = 3; // K.shape[1] and V.shape[1]
+ int v_head_size = 8; // V.shape[3]
+ int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2]
- // {2, 3, 4, 4}
- std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f};
- // {2, 3, 6, 4}
- std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f};
+ // {2, 3, 4, 8}
+ std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f};
+ // {2, 3, 6, 8}
+ std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f};
// {2, 3, 6, 4}
- std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f};
- // {2, 3, 4, 13}
- std::vector m = {-0.454545f, -0.451340f, -0.448135f, -0.444930f, -0.441725f, -0.438520f, -0.435315f, -0.432110f, -0.428904f, -0.425699f, -0.422494f, -0.419289f, -0.416084f, -0.412879f, -0.409674f, -0.406469f, -0.403263f, -0.400058f, -0.396853f, -0.393648f, -0.390443f, -0.387238f, -0.384033f, -0.380828f, -0.377622f, -0.374417f, -0.371212f, -0.368007f, -0.364802f, -0.361597f, -0.358392f, -0.355186f, -0.351981f, -0.348776f, -0.345571f, -0.342366f, -0.339161f, -0.335956f, -0.332751f, -0.329545f, -0.326340f, -0.323135f, -0.319930f, -0.316725f, -0.313520f, -0.310315f, -0.307110f, -0.303904f, -0.300699f, -0.297494f, -0.294289f, -0.291084f, -0.287879f, -0.284674f, -0.281469f, -0.278263f, -0.275058f, -0.271853f, -0.268648f, -0.265443f, -0.262238f, -0.259033f, -0.255828f, -0.252622f, -0.249417f, -0.246212f, -0.243007f, -0.239802f, -0.236597f, -0.233392f, -0.230186f, -0.226981f, -0.223776f, -0.220571f, -0.217366f, -0.214161f, -0.210956f, -0.207751f, -0.204545f, -0.201340f, -0.198135f, -0.194930f, -0.191725f, -0.188520f, -0.185315f, -0.182110f, -0.178904f, -0.175699f, -0.172494f, -0.169289f, -0.166084f, -0.162879f, -0.159674f, -0.156469f, -0.153263f, -0.150058f, -0.146853f, -0.143648f, -0.140443f, -0.137238f, -0.134033f, -0.130828f, -0.127622f, -0.124417f, -0.121212f, -0.118007f, -0.114802f, -0.111597f, -0.108392f, -0.105186f, -0.101981f, -0.098776f, -0.095571f, -0.092366f, -0.089161f, -0.085956f, -0.082751f, -0.079545f, -0.076340f, -0.073135f, -0.069930f, -0.066725f, -0.063520f, -0.060315f, -0.057110f, -0.053904f, -0.050699f, -0.047494f, -0.044289f, -0.041084f, -0.037879f, -0.034674f, -0.031469f, -0.028263f, -0.025058f, -0.021853f, -0.018648f, -0.015443f, -0.012238f, -0.009033f, -0.005828f, -0.002622f, 0.000583f, 0.003788f, 0.006993f, 0.010198f, 0.013403f, 0.016608f, 0.019814f, 0.023019f, 0.026224f, 0.029429f, 0.032634f, 0.035839f, 0.039044f, 0.042249f, 0.045455f, 0.048660f, 0.051865f, 0.055070f, 0.058275f, 0.061480f, 0.064685f, 0.067890f, 0.071096f, 0.074301f, 0.077506f, 0.080711f, 0.083916f, 0.087121f, 0.090326f, 0.093531f, 0.096737f, 0.099942f, 0.103147f, 0.106352f, 0.109557f, 0.112762f, 0.115967f, 0.119172f, 0.122378f, 0.125583f, 0.128788f, 0.131993f, 0.135198f, 0.138403f, 0.141608f, 0.144814f, 0.148019f, 0.151224f, 0.154429f, 0.157634f, 0.160839f, 0.164044f, 0.167249f, 0.170455f, 0.173660f, 0.176865f, 0.180070f, 0.183275f, 0.186480f, 0.189685f, 0.192890f, 0.196096f, 0.199301f, 0.202506f, 0.205711f, 0.208916f, 0.212121f, 0.215326f, 0.218531f, 0.221737f, 0.224942f, 0.228147f, 0.231352f, 0.234557f, 0.237762f, 0.240967f, 0.244172f, 0.247378f, 0.250583f, 0.253788f, 0.256993f, 0.260198f, 0.263403f, 0.266608f, 0.269814f, 0.273019f, 0.276224f, 0.279429f, 0.282634f, 0.285839f, 0.289044f, 0.292249f, 0.295455f, 0.298660f, 0.301865f, 0.305070f, 0.308275f, 0.311480f, 0.314685f, 0.317890f, 0.321096f, 0.324301f, 0.327506f, 0.330711f, 0.333916f, 0.337121f, 0.340326f, 0.343531f, 0.346737f, 0.349942f, 0.353147f, 0.356352f, 0.359557f, 0.362762f, 0.365967f, 0.369172f, 0.372378f, 0.375583f, 0.378788f, 0.381993f, 0.385198f, 0.388403f, 0.391608f, 0.394814f, 0.398019f, 0.401224f, 0.404429f, 0.407634f, 0.410839f, 0.414044f, 0.417249f, 0.420455f, 0.423660f, 0.426865f, 0.430070f, 0.433275f, 0.436480f, 0.439685f, 0.442890f, 0.446096f, 0.449301f, 0.452506f, 0.455711f, 0.458916f, 0.462121f, 0.465326f, 0.468531f, 0.471737f, 0.474942f, 0.478147f, 0.481352f, 0.484557f, 0.487762f, 0.490967f, 0.494172f, 0.497378f, 0.500583f, 0.503788f, 0.506993f, 0.510198f, 0.513403f, 0.516608f, 0.519814f, 0.523019f, 0.526224f, 0.529429f, 0.532634f, 0.535839f, 0.539044f, 0.542249f};
- // {2, 3, 12, 4}
- std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f};
- // {2, 3, 12, 4}
- std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f};
+ std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f};
+ // {2, 3, 4, 18}
+ std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f};
+ // {2, 3, 12, 8}
+ std::vector past_key = {0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f};
+ // {2, 3, 12, 8}
+ std::vector past_value = {0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f};
ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size);
ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size);
@@ -1181,14 +1214,15 @@ TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) {
ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size);
ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size);
- // {2, 3, 4, 4}
- std::vector y = {-0.385742f, -0.379327f, -0.372911f, -0.366496f, -0.385554f, -0.379139f, -0.372723f, -0.366308f, -0.385366f, -0.378950f, -0.372535f, -0.366119f, -0.385178f, -0.378762f, -0.372347f, -0.365931f, -0.218323f, -0.211907f, -0.205492f, -0.199076f, -0.218134f, -0.211719f, -0.205304f, -0.198888f, -0.217946f, -0.211531f, -0.205115f, -0.198700f, -0.217758f, -0.211342f, -0.204927f, -0.198512f, -0.050903f, -0.044487f, -0.038072f, -0.031657f, -0.050715f, -0.044299f, -0.037884f, -0.031468f, -0.050526f, -0.044111f, -0.037695f, -0.031280f, -0.050338f, -0.043922f, -0.037507f, -0.031092f, 0.116517f, 0.122932f, 0.129348f, 0.135763f, 0.116705f, 0.123121f, 0.129536f, 0.135952f, 0.116894f, 0.123309f, 0.129724f, 0.136140f, 0.117082f, 0.123497f, 0.129913f, 0.136328f, 0.283937f, 0.290352f, 0.296768f, 0.303183f, 0.284125f, 0.290540f, 0.296956f, 0.303371f, 0.284313f, 0.290729f, 0.297144f, 0.303559f, 0.284501f, 0.290917f, 0.297332f, 0.303747f, 0.451356f, 0.457772f, 0.464187f, 0.470602f, 0.451544f, 0.457960f, 0.464375f, 0.470790f, 0.451732f, 0.458148f, 0.464563f, 0.470978f, 0.451920f, 0.458336f, 0.464751f, 0.471166f};
- // {2, 3, 13, 4}
- std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f};
+ // {2, 3, 4, 8}
+ std::vector y = {0.431265f, 0.558994f, 0.492979f, 0.535281f, 0.609591f, 0.466737f, 0.692090f, 0.412591f, 0.468058f, 0.623595f, 0.468127f, 0.483497f, 0.577278f, 0.512802f, 0.639767f, 0.427679f, 0.422704f, 0.532822f, 0.449594f, 0.560548f, 0.608427f, 0.476187f, 0.695694f, 0.425740f, 0.447270f, 0.528366f, 0.506840f, 0.501836f, 0.547248f, 0.457381f, 0.583533f, 0.471707f, 0.414727f, 0.517263f, 0.342732f, 0.363543f, 0.677046f, 0.664675f, 0.271455f, 0.479982f, 0.438313f, 0.537211f, 0.342649f, 0.402609f, 0.660072f, 0.631518f, 0.266481f, 0.501402f, 0.458457f, 0.519536f, 0.434125f, 0.443849f, 0.614893f, 0.636419f, 0.310940f, 0.497030f, 0.433312f, 0.522457f, 0.417441f, 0.405432f, 0.617509f, 0.592985f, 0.310558f, 0.490073f, 0.499459f, 0.430465f, 0.601451f, 0.404111f, 0.502848f, 0.415186f, 0.440655f, 0.478187f, 0.536562f, 0.376663f, 0.527310f, 0.363608f, 0.443744f, 0.476396f, 0.453812f, 0.498910f, 0.483497f, 0.433209f, 0.541590f, 0.366029f, 0.513807f, 0.477506f, 0.492110f, 0.527910f, 0.471458f, 0.419741f, 0.536529f, 0.407806f, 0.512188f, 0.467064f, 0.496260f, 0.519270f, 0.683252f, 0.426643f, 0.425275f, 0.457410f, 0.611686f, 0.591234f, 0.394568f, 0.446171f, 0.637484f, 0.426481f, 0.346779f, 0.466867f, 0.585075f, 0.558250f, 0.387627f, 0.507636f, 0.658808f, 0.467355f, 0.496107f, 0.556756f, 0.513309f, 0.520842f, 0.411220f, 0.451704f, 0.661693f, 0.463543f, 0.421647f, 0.486068f, 0.552701f, 0.484705f, 0.412050f, 0.449818f, 0.637941f, 0.564086f, 0.543446f, 0.530844f, 0.627347f, 0.520370f, 0.389963f, 0.520054f, 0.574335f, 0.604007f, 0.468559f, 0.473710f, 0.559229f, 0.504183f, 0.453090f, 0.564618f, 0.568083f, 0.541180f, 0.491888f, 0.485970f, 0.564150f, 0.506989f, 0.421426f, 0.544228f, 0.616426f, 0.467555f, 0.529898f, 0.487372f, 0.574411f, 0.471969f, 0.388121f, 0.485012f, 0.533687f, 0.523210f, 0.560021f, 0.490233f, 0.443149f, 0.420163f, 0.538998f, 0.606965f, 0.586616f, 0.478324f, 0.572142f, 0.517933f, 0.441955f, 0.411890f, 0.550505f, 0.604577f, 0.541173f, 0.473423f, 0.505749f, 0.473388f, 0.389025f, 0.498730f, 0.507861f, 0.584389f, 0.519963f, 0.461030f, 0.576878f, 0.471281f, 0.461238f, 0.496673f, 0.509573f, 0.568405f};
// {2, 3, 18, 8}
- std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f};
- // {2, 3, 4, 13}
- std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f};
+ std::vector present_key = {0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f};
+ // {2, 3, 18, 8}
+ std::vector present_value = {0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f};
+ // {2, 3, 4, 18}
+ constexpr float inff = std::numeric_limits::infinity();
+ std::vector qk_matmul = {2.137658f, 1.567682f, 1.582827f, 0.953936f, 0.636597f, 1.001645f, 1.885707f, 1.361086f, 1.495408f, 1.566455f, 1.459078f, 1.668413f, 0.904174f, -inff, -inff, -inff, -inff, -inff, 1.229267f, 0.591855f, 1.372683f, 0.964445f, 1.006092f, 1.046331f, 1.712052f, 1.060710f, 2.141520f, 1.917742f, 1.063752f, 0.892409f, 0.884336f, 0.881352f, -inff, -inff, -inff, -inff, 2.235662f, 1.742821f, 2.198921f, 1.079357f, 1.510221f, 1.812315f, 1.396341f, 1.864746f, 1.498768f, 2.115730f, 0.844762f, 1.323617f, 1.096593f, 1.033003f, 1.868677f, -inff, -inff, -inff, 1.429269f, 0.876355f, 0.928405f, 1.469794f, 0.649940f, 1.435654f, 1.452830f, 1.053687f, 1.338220f, 0.966775f, 1.237266f, 1.488850f, 1.438267f, 0.931250f, 1.633272f, 0.944889f, -inff, -inff, 1.172613f, 1.105815f, 1.263303f, 1.702161f, 1.406517f, 1.808470f, 1.496128f, 1.169961f, 1.428707f, 1.393064f, 1.624670f, 1.287919f, 0.674733f, -inff, -inff, -inff, -inff, -inff, 0.838456f, 1.191558f, 1.771291f, 1.491907f, 0.911088f, 0.865799f, 1.154893f, 1.472593f, 0.826140f, 0.896018f, 1.281853f, 0.942941f, 1.470656f, 0.816028f, -inff, -inff, -inff, -inff, 1.133820f, 1.086309f, 1.712385f, 1.254675f, 1.427773f, 0.748848f, 1.056134f, 1.187805f, 1.419181f, 1.140224f, 1.269629f, 1.135934f, 0.694738f, 1.528325f, 0.959286f, -inff, -inff, -inff, 1.160321f, 1.097000f, 1.485019f, 1.111147f, 0.836961f, 0.948765f, 1.234762f, 0.835082f, 0.833382f, 0.589928f, 1.266538f, 1.303439f, 0.622733f, 0.837537f, 0.605730f, 0.730216f, -inff, -inff, 2.078597f, 0.610472f, 1.371772f, 0.794857f, 1.018924f, 1.165257f, 1.466839f, 1.206415f, 1.662507f, 1.098436f, 1.283408f, 1.533854f, 1.247966f, -inff, -inff, -inff, -inff, -inff, 1.707491f, 0.439978f, 0.919238f, 0.297115f, 0.982817f, 1.370520f, 0.766707f, 0.938981f, 1.095468f, 1.442393f, 0.742909f, 0.529869f, 0.628822f, 1.353301f, -inff, -inff, -inff, -inff, 1.483284f, 1.334536f, 0.757364f, 1.243801f, 0.767143f, 0.919318f, 0.693929f, 1.000990f, 1.107699f, 1.001247f, 1.434079f, 1.522769f, 0.696104f, 1.336034f, 0.501240f, -inff, -inff, -inff, 1.535892f, 1.342303f, 0.701559f, 1.211220f, 1.510985f, 0.961962f, 1.471503f, 1.440467f, 1.835586f, 0.947043f, 1.254547f, 1.009386f, 0.842613f, 1.508191f, 1.233544f, 1.280385f, -inff, -inff, 1.552432f, 0.958768f, 1.676495f, 1.810273f, 1.019336f, 1.487615f, 0.695035f, 1.391893f, 1.060641f, 0.917107f, 1.115109f, 1.128137f, 0.986429f, -inff, -inff, -inff, -inff, -inff, 1.289288f, 1.303667f, 0.882238f, 1.948027f, 1.580638f, 0.863439f, 1.059965f, 2.095325f, 1.493638f, 0.654104f, 0.828719f, 1.673449f, 0.479778f, 1.149678f, -inff, -inff, -inff, -inff, 1.177682f, 1.225590f, 1.735621f, 2.114078f, 1.905758f, 1.835981f, 1.432170f, 1.444457f, 2.016032f, 0.762211f, 1.059737f, 1.378216f, 1.564930f, 1.950097f, 1.598798f, -inff, -inff, -inff, 0.820477f, 0.962096f, 1.188223f, 1.264395f, 1.676953f, 1.487113f, 0.962162f, 1.377522f, 1.370079f, 1.450785f, 1.131087f, 1.962317f, 0.764849f, 0.777860f, 1.194763f, 1.030136f, -inff, -inff, 1.096708f, 1.345589f, 1.404595f, 1.370459f, 1.263369f, 1.364863f, 0.489623f, 0.596189f, 1.079480f, 0.915348f, 0.770954f, 1.548047f, 1.519504f, -inff, -inff, -inff, -inff, -inff, 1.856943f, 0.790590f, 1.235241f, 2.061177f, 1.282346f, 1.896653f, 1.112410f, 1.622862f, 0.780625f, 1.990919f, 1.693934f, 1.466544f, 1.026297f, 1.323339f, -inff, -inff, -inff, -inff, 1.778816f, 1.746915f, 1.169870f, 1.847628f, 0.729303f, 2.421048f, 1.266061f, 1.481203f, 1.016384f, 2.038725f, 1.132054f, 1.669076f, 1.958931f, 1.654780f, 1.644111f, -inff, -inff, -inff, 0.856287f, 1.124803f, 1.216201f, 0.831110f, 0.761234f, 1.204141f, 0.994307f, 0.832859f, 1.294077f, 1.566637f, 1.102631f, 1.472731f, 1.569911f, 0.779225f, 1.536189f, 1.277889f, -inff, -inff, 0.944230f, 1.585174f, 1.001532f, 0.973579f, 1.652668f, 1.112330f, 1.052878f, 1.326390f, 1.526319f, 1.790060f, 1.219317f, 1.742865f, 0.871467f, -inff, -inff, -inff, -inff, -inff, 0.794245f, 1.084904f, 0.813691f, 1.037344f, 0.254175f, 1.071614f, 0.477497f, 0.773591f, 1.317670f, 1.382451f, 0.759806f, 1.228428f, 0.583565f, 1.274037f, -inff, -inff, -inff, -inff, 0.865060f, 0.697643f, 1.300273f, 1.064195f, 1.435744f, 1.516307f, 0.626589f, 1.255387f, 1.115037f, 1.202643f, 1.789729f, 1.328769f, 1.046150f, 1.149905f, 1.696396f, -inff, -inff, -inff, 1.421552f, 1.324626f, 1.029005f, 0.960238f, 1.215132f, 1.450928f, 1.351898f, 1.718175f, 1.502146f, 1.736591f, 1.019685f, 1.130950f, 1.097223f, 1.330517f, 1.675029f, 1.069868f, -inff, -inff};
ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size);
ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size);
ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size);
@@ -1196,7 +1230,7 @@ TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) {
RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
q, k, v, m, std::initializer_list(), past_key, past_value,
- -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
+ 1, 1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, present_key, present_value, qk_matmul,
false, true, true // disable_cpu, disable_cuda, disable_dml
);
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc b/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc
index ea6b3f148979f..b0ee078335308 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc
@@ -7,14 +7,12 @@
#include "core/framework/utils.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/utils.h"
+#include "test/providers/internal_testing/internal_testing_execution_provider.h"
namespace onnxruntime {
namespace internal_testing_ep {
-// can't use 'utils::kInternalTestingExecutionProvider' in the macro so redefine here to a name without '::'
-constexpr const char* internal_testing_ep = utils::kInternalTestingExecutionProvider;
-
-ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, internal_testing_ep,
+ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kInternalTestingExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
Conv);
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
index 390329c5cae7a..934916ea862e9 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
@@ -26,17 +26,15 @@ namespace internal_testing_ep {
// NHWC Conv requires contrib ops
#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS)
-// the 'utils::' breaks the kernel registration macros
-constexpr const char* internal_testing_ep = utils::kInternalTestingExecutionProvider;
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(internal_testing_ep, kMSInternalNHWCDomain, 11, Conv);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kInternalTestingExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
// register static kernels we have implementations for
static std::unique_ptr RegisterKernels() {
auto kernel_registry = std::make_unique();
ORT_THROW_IF_ERROR(kernel_registry->Register(
- BuildKernelCreateInfo()));
return kernel_registry;
@@ -68,7 +66,7 @@ void RegisterDummyStaticKernel(KernelRegistry& registry, const Node& node) {
builder.SetName(node.OpType())
.SetDomain(node.Domain())
.SinceVersion(node.SinceVersion())
- .Provider(internal_testing_ep);
+ .Provider(kInternalTestingExecutionProvider);
ORT_THROW_IF_ERROR(registry.Register(builder, DummyCreateKernel));
}
@@ -85,7 +83,7 @@ constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP";
InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set& ops,
const std::unordered_set& stop_ops,
DataLayout preferred_layout)
- : IExecutionProvider{utils::kInternalTestingExecutionProvider},
+ : IExecutionProvider{kInternalTestingExecutionProvider},
ep_name_{INTERNAL_TESTING_EP},
ops_{ops},
stop_ops_{stop_ops},
@@ -221,7 +219,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer&
auto compile_capabilities = utils::CreateSupportedPartitions(graph_viewer, supported_compiled_nodes, stop_ops_,
generate_metadef_name, ep_name_,
- onnxruntime::utils::kInternalTestingExecutionProvider,
+ kInternalTestingExecutionProvider,
/*QDQ NodeUnit map*/ nullptr,
debug_output_);
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h
index 0caa0febc2796..8832265798798 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h
@@ -9,6 +9,9 @@
namespace onnxruntime {
namespace internal_testing_ep {
+// Provider type of `InternalTestingExecutionProvider`, an EP used for internal testing.
+constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider";
+
class InternalTestingExecutionProvider : public IExecutionProvider {
public:
InternalTestingExecutionProvider(const std::unordered_set& ops,
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc
index d58db5178032d..c085d1acd10c0 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc
@@ -57,7 +57,7 @@ auto RunTest(const std::string& op, const ORTCHAR_T* model_path) {
for (const auto& node : graph.Nodes()) {
EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0))
<< "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not.";
- if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
+ if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) {
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
EXPECT_NE(compute_func, nullptr);
++num_partitions;
@@ -116,7 +116,7 @@ TEST(InternalTestingEP, TestDependenciesCorrectlyHandled) {
for (const auto& node : graph.Nodes()) {
EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0))
<< "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not.";
- if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
+ if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) {
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
EXPECT_NE(compute_func, nullptr);
++num_partitions;
@@ -227,7 +227,7 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin
std::string unsupported_op_str;
for (const Node& node : graph.Nodes()) {
- if (node.GetExecutionProviderType() != utils::kInternalTestingExecutionProvider &&
+ if (node.GetExecutionProviderType() != kInternalTestingExecutionProvider &&
ops.count(node.OpType()) == 0) {
auto entry = unsupported_ops.find(node.OpType());
if (entry != unsupported_ops.end()) {
@@ -288,12 +288,12 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin
<< "Nodes with supported op types should have been replaced. Node with type "
<< node.OpType() << " was not.";
} else {
- EXPECT_NE(node.GetExecutionProviderType(), utils::kInternalTestingExecutionProvider)
+ EXPECT_NE(node.GetExecutionProviderType(), kInternalTestingExecutionProvider)
<< "Node is downstream from a 'stop at' node and should not have been taken. Node:"
<< node.Name();
}
- if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
+ if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) {
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
EXPECT_NE(compute_func, nullptr);
++stats.num_compiled_nodes;
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc
index 275f29fdd9073..94e60739c3ccf 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc
@@ -334,7 +334,7 @@ TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) {
for (const auto& node : graph.Nodes()) {
EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0))
<< "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not.";
- if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
+ if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) {
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
EXPECT_NE(compute_func, nullptr);
EXPECT_THAT(node.OpType(), ::testing::StartsWith(expected_op_type_prefix));
@@ -353,7 +353,7 @@ static int CountAndValidateAssignedNodes(const Graph& current_graph,
for (const auto& node : current_graph.Nodes()) {
EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0))
<< "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not.";
- if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
+ if (node.GetExecutionProviderType() == kInternalTestingExecutionProvider) {
const NodeComputeInfo* compute_func = nullptr;
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
EXPECT_NE(compute_func, nullptr);
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 1820664e1d604..b85030b46e94d 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -689,15 +689,30 @@ def test_run_model_with_optional_sequence_input(self):
def test_run_model(self):
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers)
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
- input_name = sess.get_inputs()[0].name
- self.assertEqual(input_name, "X")
- input_shape = sess.get_inputs()[0].shape
- self.assertEqual(input_shape, [3, 2])
- output_name = sess.get_outputs()[0].name
- self.assertEqual(output_name, "Y")
- output_shape = sess.get_outputs()[0].shape
- self.assertEqual(output_shape, [3, 2])
- res = sess.run([output_name], {input_name: x})
+
+ inputs = sess.get_inputs()
+ self.assertEqual(len(inputs), 1)
+ self.assertEqual(inputs[0].name, "X")
+ self.assertEqual(inputs[0].shape, [3, 2])
+
+ input_meminfos = sess.get_input_memory_infos()
+ self.assertEqual(len(input_meminfos), 1)
+ self.assertIsNotNone(input_meminfos[0])
+
+ input_epdevices = sess.get_input_epdevices()
+ # The entry my be None (null) but it should be present
+ self.assertEqual(len(input_epdevices), 1)
+
+ outputs = sess.get_outputs()
+ self.assertEqual(len(outputs), 1)
+ self.assertEqual(outputs[0].name, "Y")
+ self.assertEqual(outputs[0].shape, [3, 2])
+
+ output_meminfos = sess.get_output_memory_infos()
+ self.assertEqual(len(output_meminfos), 1)
+ self.assertIsNotNone(output_meminfos[0])
+
+ res = sess.run([outputs[0].name], {inputs[0].name: x})
output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)
@@ -1584,6 +1599,44 @@ def test_run_model_with_cuda_copy_stream(self):
for _iteration in range(100000):
session.run(output_names=["output"], input_feed={"shape": shape})
+ def test_ort_device(self):
+ cpu_device = onnxrt.OrtDevice.make("cpu", 0)
+ self.assertEqual(cpu_device.device_id(), 0)
+ self.assertEqual(cpu_device.device_type(), 0)
+ self.assertEqual(cpu_device.device_vendor_id(), 0)
+ self.assertEqual(cpu_device.device_mem_type(), 0)
+
+ def test_ort_memory_info(self):
+ cpu_memory_info = onnxrt.OrtMemoryInfo(
+ "Cpu",
+ onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR,
+ 0,
+ onnxrt.OrtMemType.DEFAULT,
+ )
+ self.assertEqual(cpu_memory_info.name, "Cpu")
+ self.assertEqual(cpu_memory_info.device_id, 0)
+ self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT)
+ self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR)
+ self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT)
+ self.assertEqual(cpu_memory_info.device_vendor_id, 0)
+
+ def test_ort_memory_info_create_v2(self):
+ cpu_memory_info = onnxrt.OrtMemoryInfo.create_v2(
+ "Test",
+ onnxrt.OrtMemoryInfoDeviceType.CPU,
+ 0, # vendor_id
+ 0, # device_id
+ onnxrt.OrtDeviceMemoryType.DEFAULT,
+ 128, # alignment
+ onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR,
+ )
+ self.assertEqual(cpu_memory_info.name, "Test")
+ self.assertEqual(cpu_memory_info.device_id, 0)
+ self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT)
+ self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR)
+ self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT)
+ self.assertEqual(cpu_memory_info.device_vendor_id, 0)
+
def test_shared_allocator_using_create_and_register_allocator(self):
# Create and register an arena based allocator
diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
index cb31627a87c48..d6281d165c053 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
@@ -226,6 +226,14 @@ def test_example_plugin_ep_devices(self):
hw_metadata = hw_device.metadata
self.assertGreater(len(hw_metadata), 0) # Should have at least SPDRP_HARDWAREID on Windows
+ test_mem_info = test_ep_device.memory_info(onnxrt.OrtDeviceMemoryType.DEFAULT)
+ self.assertIsNotNone(test_mem_info)
+ del test_mem_info
+
+ test_sync_stream = test_ep_device.create_sync_stream()
+ self.assertIsNotNone(test_sync_stream)
+ del test_sync_stream
+
# Add EP plugin's OrtEpDevice to the SessionOptions.
sess_options = onnxrt.SessionOptions()
sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"})
@@ -282,6 +290,55 @@ def test_example_plugin_ep_data_transfer(self):
self.unregister_execution_provider_library(ep_name)
+ def test_copy_tensors(self):
+ """
+ Test global api copy_tensors between OrtValue objects
+ using EP plug-in data transfer
+ """
+ if sys.platform != "win32":
+ self.skipTest("Skipping test because device discovery is only supported on Windows")
+
+ ep_lib_path = "example_plugin_ep.dll"
+ try:
+ ep_lib_path = get_name("example_plugin_ep.dll")
+ except FileNotFoundError:
+ self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found")
+
+ ep_name = "example_ep"
+ self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path))
+
+ # Generate 2 numpy arrays
+ a = np.random.rand(3, 2).astype(np.float32)
+ b = np.random.rand(3, 2).astype(np.float32)
+
+ # Create OrtValue from numpy arrays on EP device
+ # the example EP pretends to use GPU memory, so we place it there
+ a_device = onnxrt.OrtValue.ortvalue_from_numpy(a, "gpu", 0, 0xBE57)
+ b_device = onnxrt.OrtValue.ortvalue_from_numpy(b, "gpu", 0, 0xBE57)
+
+ # Create destination ort values with the same shape on CPU
+ a_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(a.shape, a.dtype)
+ b_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(b.shape, b.dtype)
+
+ # source list
+ src_list = [a_device, b_device]
+ dst_list = [a_cpu_copy, b_cpu_copy]
+ # Passing None for stream as we copy between CPU
+ # Test None because it is allowed
+ onnxrt.copy_tensors(src_list, dst_list, None)
+
+ # Release the OrtValue on the EP device
+ # before the EP library is unregistered
+ del src_list
+ del a_device
+ del b_device
+
+ # Verify the contents
+ np.testing.assert_array_equal(a, a_cpu_copy.numpy())
+ np.testing.assert_array_equal(b, b_cpu_copy.numpy())
+
+ self.unregister_execution_provider_library(ep_name)
+
if __name__ == "__main__":
unittest.main(verbosity=1)
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index b7a9da8e1b658..5199730ae323d 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -494,6 +494,35 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders,
CApiTestWithProvider,
::testing::Values(0, 1, 2, 3, 4));
+TEST(CApiTest, DISABLED_TestInputPassThroughToOutput) {
+ const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx");
+ Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{});
+ auto inputs_meminfos = session.GetMemoryInfoForInputs();
+ ASSERT_EQ(1U, inputs_meminfos.size());
+ auto inputs_epdevices = session.GetEpDeviceForInputs();
+ ASSERT_EQ(1U, inputs_epdevices.size());
+ auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+ ASSERT_EQ(7U, outputs_meminfos.size());
+}
+
+TEST(CApiTest, DISABLED_TestDanglingInput) {
+ // Here we test an issue with segments_ids that is an input not consumed by anything
+ // This kind of model is unlikely to be used in practice but we want to make sure it works
+ const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx");
+ Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{});
+ auto inputs_meminfos = session.GetMemoryInfoForInputs();
+ ASSERT_EQ(2U, inputs_meminfos.size());
+ auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+ ASSERT_EQ(2U, outputs_meminfos.size());
+ auto inputs_epdevices = session.GetEpDeviceForInputs();
+ ASSERT_EQ(2U, inputs_epdevices.size());
+ // One of the devices returning is null since the input is not consumed
+ // there is not a device for it.
+ const bool null_present = std::any_of(inputs_epdevices.begin(), inputs_epdevices.end(),
+ [](const auto& device) { return device == nullptr; });
+ ASSERT_TRUE(null_present);
+}
+
#if !defined(DISABLE_SPARSE_TENSORS)
TEST(CApiTest, SparseOutputModel) {
std::vector dense_shape{3, 3};
@@ -505,7 +534,15 @@ TEST(CApiTest, SparseOutputModel) {
std::vector ort_inputs;
std::vector input_names;
const char* const output_names[] = {"values"};
+ // This model produces a sparse output from a constant sparse initializer
Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{});
+ auto inputs_meminfos = session.GetMemoryInfoForInputs();
+ ASSERT_TRUE(inputs_meminfos.empty());
+ auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+ ASSERT_EQ(1U, outputs_meminfos.size());
+ auto inputs_epdevices = session.GetEpDeviceForInputs();
+ ASSERT_TRUE(inputs_epdevices.empty());
+
auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
output_names, 1);
ASSERT_EQ(ort_outputs.size(), 1U);
diff --git a/onnxruntime/test/testdata/add_mul_add.onnx b/onnxruntime/test/testdata/add_mul_add.onnx
new file mode 100644
index 0000000000000..0e2bc1bb9cff9
--- /dev/null
+++ b/onnxruntime/test/testdata/add_mul_add.onnx
@@ -0,0 +1,28 @@
+
+:´
+
+A
+B
+add_outputadd_0"Add
+'
+
+add_output
+B
+mul_outputmul_0"Mul
+
+
+mul_output
+ACadd_1"Add
+Main_graphZ
+A
+
+
+Z
+B
+
+
+b
+C
+
+
+B
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/add_mul_add.py b/onnxruntime/test/testdata/add_mul_add.py
new file mode 100644
index 0000000000000..c22a176e065dd
--- /dev/null
+++ b/onnxruntime/test/testdata/add_mul_add.py
@@ -0,0 +1,37 @@
+from onnx import TensorProto, checker, helper, save
+
+# (A + B) * B + A
+graph_proto = helper.make_graph(
+ nodes=[
+ helper.make_node(
+ "Add",
+ inputs=["A", "B"],
+ outputs=["add_output"],
+ name="add_0",
+ ),
+ helper.make_node(
+ "Mul",
+ inputs=["add_output", "B"],
+ outputs=["mul_output"],
+ name="mul_0",
+ ),
+ helper.make_node(
+ "Add",
+ inputs=["mul_output", "A"],
+ outputs=["C"],
+ name="add_1",
+ ),
+ ],
+ name="Main_graph",
+ inputs=[
+ helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 2]),
+ helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("C", TensorProto.FLOAT, [3, 2]),
+ ],
+)
+
+model = helper.make_model(graph_proto)
+checker.check_model(model, True)
+save(model, "add_mul_add.onnx")
diff --git a/onnxruntime/test/testdata/input_propagated_to_output.onnx b/onnxruntime/test/testdata/input_propagated_to_output.onnx
new file mode 100644
index 0000000000000..feeab10556cb0
Binary files /dev/null and b/onnxruntime/test/testdata/input_propagated_to_output.onnx differ
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index 43f6e480672ba..a04aafecbc81a 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -32,6 +32,31 @@
"^test_adagrad",
"^test_adagrad_multiple",
"^test_attention_3d.*", // wrong expected values in onnx==1.18.0, fixed in 1.19.0
+ "^test_attention_4d_diff_heads_mask4d_padded_kv*", // pending onnx update
+ "^test_attention_3d_gqa*", // pending onnx update
+ "^test_attention_3d_gqa_causal", // pending onnx update
+ "^test_attention_3d_gqa_scaled", // pending onnx update
+ "^test_attention_3d_gqa_softcap", // pending onnx update
+ "^test_attention_3d_gqa_with_past_and_present", // pending onnx update
+ "^test_attention_4d_gqa*", // pending onnx update
+ "^test_attention_4d_gqa_causal", // pending onnx update
+ "^test_attention_4d_gqa_scaled", // pending onnx update
+ "^test_attention_4d_gqa_softcap", // pending onnx update
+ "^test_attention_4d_gqa_with_past_and_present", // pending onnx update
+ "^test_attention_*causal*", // pending onnx update
+ "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal*", // pending onnx update
+ "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal*", // pending onnx update
+ "^test_attention_4d_attn_mask_3d_causal_expanded*", // pending onnx update
+ "^test_attention_4d_fp16*", // precision issue: 1 / 192 mismatched elements
+ "^test_attention_4d_fp16_expanded*", // precision issue: 3 / 192 mismatched elements
+ "^test_l2normalization*", // LpNormalization(22) not implemented
+ "^test_l1normalization*", // LpNormalization(22) not implemented
+ "^test_lpnormalization*", // LpNormalization(22) not implemented
+ "^test_tensorscatter*", // TensorScatter(24) not implemented
+ "^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes
+ "^test_castlike_UINT4_to*", // ORT does not support ml_dtypes
+ "^test_castlike_INT4_to*", // ORT does not support ml_dtypes
+ "^test_cast_e8m0_*", // ORT does not support float8e8m0
"^test_batchnorm_epsilon_training_mode",
"^test_batchnorm_example_training_mode",
"^test_col2im_pads", // still one wrong value coming from the backtest example
diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx
new file mode 100644
index 0000000000000..a83c21030ad67
Binary files /dev/null and b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx differ
diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py
new file mode 100644
index 0000000000000..c5eb8a600d6b5
--- /dev/null
+++ b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py
@@ -0,0 +1,86 @@
+"""
+Run this script to recreate the original onnx model.
+Example usage:
+python test_dangling_input_segment_ids.py out_model_path.onnx
+"""
+
+import os
+import sys
+
+import numpy as np
+import onnx
+from onnx import TensorProto, helper, numpy_helper
+
+DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids")
+
+
+def order_repeated_field(repeated_proto, key_name, order):
+ order = list(order)
+ repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))
+
+
+def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
+ node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
+ if doc_string == "":
+ node.doc_string = ""
+ order_repeated_field(node.attribute, "name", kwargs.keys())
+ return node
+
+
+def make_graph(*args, doc_string=None, **kwargs):
+ graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
+ if doc_string == "":
+ graph.doc_string = ""
+ return graph
+
+
+model = helper.make_model(
+ opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)],
+ ir_version=7,
+ graph=make_graph(
+ name="embed_layernorm_graph",
+ inputs=[
+ helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]),
+ helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]),
+ helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]),
+ ],
+ initializer=[
+ numpy_helper.from_array(
+ np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]),
+ name="word_embed",
+ ),
+ numpy_helper.from_array(
+ np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]),
+ name="pos_embed",
+ ),
+ numpy_helper.from_array(
+ np.array(
+ [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495],
+ dtype="float32",
+ ),
+ name="gamma",
+ ),
+ numpy_helper.from_array(
+ np.array(
+ [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32"
+ ),
+ name="beta",
+ ),
+ ],
+ nodes=[
+ make_node(
+ "EmbedLayerNormalization",
+ inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"],
+ outputs=["layernorm_out", "mask_index_out"],
+ domain="com.microsoft",
+ )
+ ],
+ ),
+)
+
+if __name__ == "__main__" and len(sys.argv) == 2:
+ _, out_path = sys.argv
+ onnx.save(model, out_path)