diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index f2ff9f9afec8a..315a66edd1904 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -98,6 +98,10 @@ if (onnxruntime_ENABLE_TRAINING) source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src}) endif() +if (onnxruntime_BUILD_MS_EXPERIMENTAL_OPS) + target_compile_definitions(onnxruntime_graph PRIVATE BUILD_MS_EXPERIMENTAL_OPS=1) +endif() + if (WIN32) set(onnxruntime_graph_static_library_flags -IGNORE:4221 # LNK4221: This object file does not define any previously undefined public symbols, so it will not be used by any link operation that consumes this library diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 901a12b714c72..a34b9f9f97d9a 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -171,6 +171,10 @@ if (MSVC) endif() onnxruntime_add_include_to_target(onnxruntime_providers onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf flatbuffers) +if (onnxruntime_BUILD_MS_EXPERIMENTAL_OPS) + target_compile_definitions(onnxruntime_providers PRIVATE BUILD_MS_EXPERIMENTAL_OPS=1) +endif() + if (onnxruntime_USE_FEATURIZERS) add_dependencies(onnxruntime_providers onnxruntime_featurizers) onnxruntime_add_include_to_target(onnxruntime_providers onnxruntime_featurizers) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 20f25d62340f2..436ef5ca45629 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -541,8 +541,16 @@ endif(onnxruntime_USE_DML) # Add static library that will be archived/linked for both static/dynamic library add_library(winml_lib_api_experimental STATIC - ${winml_lib_api_experimental_dir}/Dummy.cpp - ${winml_lib_api_experimental_dir}/Dummy.h + ${winml_lib_api_experimental_dir}/LearningModelBuilder.cpp + ${winml_lib_api_experimental_dir}/LearningModelBuilder.h + ${winml_lib_api_experimental_dir}/LearningModelInputs.cpp + ${winml_lib_api_experimental_dir}/LearningModelInputs.h + ${winml_lib_api_experimental_dir}/LearningModelOutputs.cpp + ${winml_lib_api_experimental_dir}/LearningModelOutputs.h + ${winml_lib_api_experimental_dir}/LearningModelOperator.cpp + ${winml_lib_api_experimental_dir}/LearningModelOperator.h + ${winml_lib_api_experimental_dir}/LearningModelOperatorSet.cpp + ${winml_lib_api_experimental_dir}/LearningModelOperatorSet.h ${winml_lib_api_experimental_dir}/LearningModelSessionExperimental.cpp ${winml_lib_api_experimental_dir}/LearningModelSessionExperimental.h ${winml_lib_api_experimental_dir}/LearningModelSessionOptionsExperimental.cpp @@ -568,7 +576,7 @@ target_precompiled_header(winml_lib_api_experimental pch.h) # Includes target_include_directories(winml_lib_api_experimental PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers target_include_directories(winml_lib_api_experimental PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers -target_include_directories(winml_lib_api_experimental PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api_experimental) # windows machine learning generated component headers +target_include_directories(winml_lib_api_experimental PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api_experimental) # windows machine learning generated component headers target_include_directories(winml_lib_api_experimental PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api_experimental/comp_generated) # windows machine learning generated component headers target_include_directories(winml_lib_api_experimental PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers diff --git a/cmake/winml_cppwinrt.cmake b/cmake/winml_cppwinrt.cmake index 3c8056cc4abc1..8061033fd251d 100644 --- a/cmake/winml_cppwinrt.cmake +++ b/cmake/winml_cppwinrt.cmake @@ -246,4 +246,4 @@ function(add_generate_cppwinrt_sdk_headers_target set_target_properties(${target_name} PROPERTIES FOLDER ${folder_name}) endif() -endfunction() +endfunction() \ No newline at end of file diff --git a/cmake/winml_unittests.cmake b/cmake/winml_unittests.cmake index b3ebaff1bf119..a92503c64f433 100644 --- a/cmake/winml_unittests.cmake +++ b/cmake/winml_unittests.cmake @@ -61,6 +61,10 @@ function(add_winml_test) target_compile_definitions(${_UT_TARGET} PRIVATE "BUILD_INBOX=1") endif() + if (onnxruntime_BUILD_MS_EXPERIMENTAL_OPS) + target_compile_definitions(${_UT_TARGET} PRIVATE "BUILD_MS_EXPERIMENTAL_OPS=1") + endif() + add_test(NAME ${_UT_TARGET} COMMAND ${_UT_TARGET} WORKING_DIRECTORY $ diff --git a/csharp/src/Microsoft.AI.MachineLearning.Interop/Microsoft.AI.MachineLearning.Interop.csproj b/csharp/src/Microsoft.AI.MachineLearning.Interop/Microsoft.AI.MachineLearning.Interop.csproj index ca31e658f8db8..689d281d025fd 100644 --- a/csharp/src/Microsoft.AI.MachineLearning.Interop/Microsoft.AI.MachineLearning.Interop.csproj +++ b/csharp/src/Microsoft.AI.MachineLearning.Interop/Microsoft.AI.MachineLearning.Interop.csproj @@ -30,6 +30,7 @@ + diff --git a/csharp/src/Microsoft.AI.MachineLearning/Microsoft.AI.MachineLearning.targets b/csharp/src/Microsoft.AI.MachineLearning/Microsoft.AI.MachineLearning.targets index bfecff8c983ee..807cf89e303ef 100644 --- a/csharp/src/Microsoft.AI.MachineLearning/Microsoft.AI.MachineLearning.targets +++ b/csharp/src/Microsoft.AI.MachineLearning/Microsoft.AI.MachineLearning.targets @@ -20,6 +20,9 @@ $(WindowsAIBinary) + + $(WindowsAIBinary) + diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index dc1b4d0bbdd39..8ce6147c090ab 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -13,6 +13,7 @@ constexpr const char* kOnnxDomain = ""; constexpr const char* kOnnxDomainAlias = "ai.onnx"; constexpr const char* kMLDomain = "ai.onnx.ml"; constexpr const char* kMSDomain = "com.microsoft"; +constexpr const char* kMSExperimentalDomain = "com.microsoft.experimental"; constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc"; constexpr const char* kMSFeaturizersDomain = "com.microsoft.mlfeaturizers"; constexpr const char* kMSDmlDomain = "com.microsoft.dml"; diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 2df5a1b9fc13f..66817cfdf61f4 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -34,6 +34,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu) class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu); +#ifdef BUILD_MS_EXPERIMENTAL_OPS +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, DFT); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, IDFT); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, HannWindow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, HammingWindow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, BlackmanWindow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, MelWeightMatrix); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, STFT); +#endif + // ******** Start: Quantization ******************* // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool); @@ -179,6 +189,15 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, +#ifdef BUILD_MS_EXPERIMENTAL_OPS + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to main backward compatibility BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/signal/dft.cc b/onnxruntime/contrib_ops/cpu/signal/dft.cc new file mode 100644 index 0000000000000..267290dfba4a2 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/signal/dft.cc @@ -0,0 +1,521 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef BUILD_MS_EXPERIMENTAL_OPS + +#include "core/providers/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" +#include "Eigen/src/Core/Map.h" +#include "dft.h" +#include + +#include "core/platform/threadpool.h" + +#include + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + DFT, + kMSExperimentalDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()), + DFT); + +ONNX_OPERATOR_KERNEL_EX( + IDFT, + kMSExperimentalDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()), + IDFT); + +ONNX_OPERATOR_KERNEL_EX( + STFT, + kMSExperimentalDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()), + STFT); + +static bool is_real_valued_signal(const onnxruntime::TensorShape & shape) { + // The first dimention is the batch size + // The second dimention is the signal value + return shape.NumDimensions() == 2; +} + +static bool is_complex_valued_signal(const onnxruntime::TensorShape& shape) { + // The first dimention is the batch size + // The second dimention is the signal length + // The third dimention is set to 2 and represents the real and imaginary parts of the complex sample + return shape.NumDimensions() == 3 && shape[2] == 2; +} + +static bool is_power_of_2(size_t size) { + unsigned n_bits = 0; + while (size != 0) { + n_bits += size & 1; + size = size >> 1; + } + return n_bits == 1; +} + +static const unsigned char BitReverseTable256[] = +{ + 0x00, 0x80, 0x40, 0xC0, 0x20, 0xA0, 0x60, 0xE0, 0x10, 0x90, 0x50, 0xD0, 0x30, 0xB0, 0x70, 0xF0, + 0x08, 0x88, 0x48, 0xC8, 0x28, 0xA8, 0x68, 0xE8, 0x18, 0x98, 0x58, 0xD8, 0x38, 0xB8, 0x78, 0xF8, + 0x04, 0x84, 0x44, 0xC4, 0x24, 0xA4, 0x64, 0xE4, 0x14, 0x94, 0x54, 0xD4, 0x34, 0xB4, 0x74, 0xF4, + 0x0C, 0x8C, 0x4C, 0xCC, 0x2C, 0xAC, 0x6C, 0xEC, 0x1C, 0x9C, 0x5C, 0xDC, 0x3C, 0xBC, 0x7C, 0xFC, + 0x02, 0x82, 0x42, 0xC2, 0x22, 0xA2, 0x62, 0xE2, 0x12, 0x92, 0x52, 0xD2, 0x32, 0xB2, 0x72, 0xF2, + 0x0A, 0x8A, 0x4A, 0xCA, 0x2A, 0xAA, 0x6A, 0xEA, 0x1A, 0x9A, 0x5A, 0xDA, 0x3A, 0xBA, 0x7A, 0xFA, + 0x06, 0x86, 0x46, 0xC6, 0x26, 0xA6, 0x66, 0xE6, 0x16, 0x96, 0x56, 0xD6, 0x36, 0xB6, 0x76, 0xF6, + 0x0E, 0x8E, 0x4E, 0xCE, 0x2E, 0xAE, 0x6E, 0xEE, 0x1E, 0x9E, 0x5E, 0xDE, 0x3E, 0xBE, 0x7E, 0xFE, + 0x01, 0x81, 0x41, 0xC1, 0x21, 0xA1, 0x61, 0xE1, 0x11, 0x91, 0x51, 0xD1, 0x31, 0xB1, 0x71, 0xF1, + 0x09, 0x89, 0x49, 0xC9, 0x29, 0xA9, 0x69, 0xE9, 0x19, 0x99, 0x59, 0xD9, 0x39, 0xB9, 0x79, 0xF9, + 0x05, 0x85, 0x45, 0xC5, 0x25, 0xA5, 0x65, 0xE5, 0x15, 0x95, 0x55, 0xD5, 0x35, 0xB5, 0x75, 0xF5, + 0x0D, 0x8D, 0x4D, 0xCD, 0x2D, 0xAD, 0x6D, 0xED, 0x1D, 0x9D, 0x5D, 0xDD, 0x3D, 0xBD, 0x7D, 0xFD, + 0x03, 0x83, 0x43, 0xC3, 0x23, 0xA3, 0x63, 0xE3, 0x13, 0x93, 0x53, 0xD3, 0x33, 0xB3, 0x73, 0xF3, + 0x0B, 0x8B, 0x4B, 0xCB, 0x2B, 0xAB, 0x6B, 0xEB, 0x1B, 0x9B, 0x5B, 0xDB, 0x3B, 0xBB, 0x7B, 0xFB, + 0x07, 0x87, 0x47, 0xC7, 0x27, 0xA7, 0x67, 0xE7, 0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7, + 0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F, 0xEF, 0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF}; + +template +uint32_t bit_reverse(uint32_t num) { + uint32_t rev = (BitReverseTable256[num & 0xff] << 24) | + (BitReverseTable256[(num >> 8) & 0xff] << 16) | + (BitReverseTable256[(num >> 16) & 0xff] << 8) | + (BitReverseTable256[(num >> 24) & 0xff]); + return static_cast(((uint64_t)rev) >> (32 - TSignificantBits)); +} + +template +static inline T bit_reverse(T num, unsigned significant_bits) { + switch (significant_bits) { + case 0: return static_cast(bit_reverse<0>(static_cast(num))); + case 1: return static_cast(bit_reverse<1>(static_cast(num))); + case 2: return static_cast(bit_reverse<2>(static_cast(num))); + case 3: return static_cast(bit_reverse<3>(static_cast(num))); + case 4: return static_cast(bit_reverse<4>(static_cast(num))); + case 5: return static_cast(bit_reverse<5>(static_cast(num))); + case 6: return static_cast(bit_reverse<6>(static_cast(num))); + case 7: return static_cast(bit_reverse<7>(static_cast(num))); + case 8: return static_cast(bit_reverse<8>(static_cast(num))); + case 9: return static_cast(bit_reverse<9>(static_cast(num))); + case 10: return static_cast(bit_reverse<10>(static_cast(num))); + case 11: return static_cast(bit_reverse<11>(static_cast(num))); + case 12: return static_cast(bit_reverse<12>(static_cast(num))); + case 13: return static_cast(bit_reverse<13>(static_cast(num))); + case 14: return static_cast(bit_reverse<14>(static_cast(num))); + case 15: return static_cast(bit_reverse<15>(static_cast(num))); + case 16: return static_cast(bit_reverse<16>(static_cast(num))); + case 17: return static_cast(bit_reverse<17>(static_cast(num))); + case 18: return static_cast(bit_reverse<18>(static_cast(num))); + case 19: return static_cast(bit_reverse<19>(static_cast(num))); + case 20: return static_cast(bit_reverse<20>(static_cast(num))); + case 21: return static_cast(bit_reverse<21>(static_cast(num))); + case 22: return static_cast(bit_reverse<22>(static_cast(num))); + case 23: return static_cast(bit_reverse<23>(static_cast(num))); + case 24: return static_cast(bit_reverse<24>(static_cast(num))); + case 25: return static_cast(bit_reverse<25>(static_cast(num))); + case 26: return static_cast(bit_reverse<26>(static_cast(num))); + case 27: return static_cast(bit_reverse<27>(static_cast(num))); + case 28: return static_cast(bit_reverse<28>(static_cast(num))); + case 29: return static_cast(bit_reverse<29>(static_cast(num))); + case 30: return static_cast(bit_reverse<30>(static_cast(num))); + case 31: return static_cast(bit_reverse<31>(static_cast(num))); + case 32: return static_cast(bit_reverse<32>(static_cast(num))); + default: ORT_THROW("Unsupported bit size."); + } +} + +template +static T compute_angular_velocity(size_t number_of_samples, bool inverse) { + // Calculate fundamental angular velocity + static const T pi = static_cast(3.14159265); + static const T tau = 2 * pi; + T inverse_switch = inverse ? 1.f : -1.f; + T angular_velocity = inverse_switch * tau / number_of_samples; + return angular_velocity; +} + +template +static Status fft_radix2(OpKernelContext* /*ctx*/, size_t batch_idx, + const Tensor* X, Tensor* Y, const Tensor* window, bool is_onesided, bool inverse, + std::vector>& V, + std::vector>& temp_output) { + + // Get shape and significant bits + const auto& X_shape = X->Shape(); + size_t number_of_samples = static_cast(X_shape[1]); + unsigned significant_bits = static_cast(log2(number_of_samples)); + + // Get data + auto* X_data = const_cast(reinterpret_cast(X->DataRaw())) + (batch_idx * number_of_samples); + // Get window + U* window_data = nullptr; + if (window) { + window_data = const_cast(reinterpret_cast(window->DataRaw())); + } + + std::complex* Y_data; + if (is_onesided) { + if (temp_output.size() != number_of_samples) { + temp_output = std::vector>(number_of_samples); + } + Y_data = temp_output.data(); + } else { + Y_data = reinterpret_cast*>(Y->MutableDataRaw()) + (batch_idx * number_of_samples); + } + + auto angular_velocity = compute_angular_velocity(number_of_samples, inverse); + + // Create vandermonde matrix V ordered with the bit-reversed permutation + if (V.size() != number_of_samples) { + V = std::vector>(number_of_samples); // e^(i *2*pi / N * k) + for (size_t i = 0; i < number_of_samples; i++) { + size_t bit_reversed_index = bit_reverse(i, significant_bits); + V[bit_reversed_index] = std::complex(cos(i * angular_velocity), sin(i * angular_velocity)); + } + } + + for (size_t i = 0; i < number_of_samples; i++) { + size_t bit_reversed_index = bit_reverse(i, significant_bits); + auto x = *(X_data + bit_reversed_index); + auto window_element = window_data ? *(window_data + bit_reversed_index) : 1; + *(Y_data + i) = std::complex(1, 0) * x * window_element; + } + + // Run fft_radix2 + unsigned current_significant_bits = 0; + for (size_t i = 2; i <= number_of_samples; i <<= 1) { + size_t midpoint = i >> 1; + current_significant_bits++; + + for (size_t k = 0; k < midpoint; k++) { + auto first_idx = bit_reverse(k, current_significant_bits); + auto second_idx = bit_reverse(midpoint + k, current_significant_bits); + for (size_t j = 0; j < number_of_samples; j += i) { + std::complex* even = (Y_data + j) + k; + std::complex* odd = (Y_data + j) + (midpoint + k); + std::complex first = *even + (V[first_idx] * *odd); + std::complex second = *even + (V[second_idx] * *odd); + *even = first; + *odd = second; + } + } + } + + // Scale the output if inverse + if (inverse) { + for (size_t i = 0; i < number_of_samples; i++) { + std::complex& val = *(Y_data + i); + val /= static_cast(number_of_samples); + } + } + + if (is_onesided) { + const auto& Y_shape = Y->Shape(); + size_t fft_output_size = static_cast(Y_shape[1]); + auto destination = reinterpret_cast*>(Y->MutableDataRaw()) + (batch_idx * fft_output_size); + memcpy(destination, Y_data, sizeof(std::complex) * fft_output_size); + } + + return Status::OK(); +} + +template +static Status dft_naive(size_t batch_idx, const Tensor* X, Tensor* Y, const Tensor* window, bool inverse) { + // Get shape and significant bits + const auto& X_shape = X->Shape(); + size_t number_of_samples = static_cast(X_shape[1]); + const auto& Y_shape = Y->Shape(); + size_t dft_output_size = static_cast(Y_shape[1]); + + // Get data + auto* X_data = const_cast(reinterpret_cast(X->DataRaw())) + (batch_idx * number_of_samples); + auto* Y_data = reinterpret_cast*>(Y->MutableDataRaw()) + (batch_idx * dft_output_size); + + U* window_data = nullptr; + if (window) { + window_data = const_cast(reinterpret_cast(window->DataRaw())); + } + + auto angular_velocity = compute_angular_velocity(number_of_samples, inverse); + + for (size_t i = 0; i < dft_output_size; i++) { + std::complex& out = *(Y_data + i); + out.real(0); + out.imag(0); + + for (size_t j = 0; j < number_of_samples; j++) { // vectorize over this loop + auto exponential = std::complex(cos(i * j * angular_velocity), sin(i * j * angular_velocity)); + auto window_element = window_data ? * (window_data + j) : 1; + auto element = *(X_data + j) * window_element; + out += exponential * element; + } + + if (inverse) { + out /= static_cast(number_of_samples); + } + } + + return Status::OK(); +} + +template +static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, Tensor* Y, const Tensor* window, bool is_onesided, bool inverse, + std::vector>& V, std::vector>& temp_output) { + // Get shape + const auto& X_shape = X->Shape(); + size_t number_of_batches = static_cast(X_shape[0]); + size_t number_of_samples = static_cast(X_shape[1]); + + // radix 2 fft + for (size_t i = 0; i < number_of_batches; i++) { + if (is_power_of_2(number_of_samples)) { + ORT_RETURN_IF_ERROR((fft_radix2(ctx, i, X, Y, window, is_onesided, inverse, V, temp_output))); + } else { + ORT_RETURN_IF_ERROR((dft_naive(i, X, Y, window, inverse))); + } + } + + return Status::OK(); +} + +static Status discrete_fourier_transform(OpKernelContext* ctx, bool is_onesided, bool inverse) { + // Get input shape + const auto* X = ctx->Input(0); + const auto& X_shape = X->Shape(); + const auto is_real_valued = is_real_valued_signal(X_shape); + const auto is_complex_valued = is_complex_valued_signal(X_shape); + + // Get the DFT output size. Onesided will return only the unique values! + // note: x >> 1 === std::floor(x / 2.f) + int64_t number_of_samples = static_cast(X_shape[1]); + auto dft_output_size = is_onesided ? + ((number_of_samples >> 1) + 1) : + number_of_samples; + + // Get output shape + auto Y_shape = onnxruntime::TensorShape({X_shape[0], dft_output_size, 2}); + auto Y = ctx->Output(0, Y_shape); + + // Get data type + auto data_type = X->DataType(); + + auto element_size = data_type->Size(); + if (element_size == sizeof(float)) { + std::vector> V; + std::vector> temp_output; + if (is_real_valued) { + ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, X, Y, nullptr, is_onesided, inverse, V, temp_output))); + } else if (is_complex_valued) { + ORT_RETURN_IF_ERROR((discrete_fourier_transform>(ctx, X, Y, nullptr, is_onesided, inverse, V, temp_output))); + } else { + ORT_THROW("Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for complex inputs.", data_type); + } + } else if (element_size == sizeof(double)) { + std::vector> V; + std::vector> temp_output; + if (is_real_valued) { + ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, X, Y, nullptr, is_onesided, inverse, V, temp_output))); + } else if (is_complex_valued) { + ORT_RETURN_IF_ERROR((discrete_fourier_transform>(ctx, X, Y, nullptr, is_onesided, inverse, V, temp_output))); + } else { + ORT_THROW("Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for complex inputs.", data_type); + } + } else { + ORT_THROW("Unsupported input data type of ", data_type); + } + + return Status::OK(); +} + +Status DFT::Compute(OpKernelContext* ctx) const { + ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, is_onesided_, false)); + return Status::OK(); +} + +Status IDFT::Compute(OpKernelContext* ctx) const { + ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, false, true)); + return Status::OK(); +} + +// dedupe with the other one in window_functions.cc +template +static T get_scalar_value_from_tensor(const Tensor* tensor) { + ORT_ENFORCE(tensor->Shape().Size() == 1, "ratio input should have a single value."); + + auto data_type = tensor->DataType()->AsPrimitiveDataType()->GetDataType(); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + default: + ORT_THROW("Unsupported input data type of ", data_type); + } +} + +template +static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_onesided, bool /*inverse*/) { + // Attr("onesided"): default = 1 + // Input(0, "signal") type = T1 + // Input(1, "frame_length") type = T2 + // Input(2, "window") type = T1, optional + // Input(3, "frame_step") type = T2 + // Output(0, "output") type = T1 + + // Get signal + const auto* signal = ctx->Input(0); + const auto* window = ctx->Input(1); + const auto* frame_length_tensor = ctx->Input(2); + const auto frame_step = get_scalar_value_from_tensor(ctx->Input(3)); + + // Get input signal shape + const auto& signal_shape = signal->Shape(); + const auto batch_size = signal_shape[0]; + const auto signal_size = signal_shape[1]; + const auto signal_components = + signal_shape.NumDimensions() == 2 ? 1 : signal_shape.NumDimensions() == 3 ? signal_shape[2] : 0; // error + ORT_ENFORCE(signal_components == 1 || signal_components == 2, "Ensure that the signal has either 1 or 2 components."); + + // Get the frame length + int64_t frame_length = std::numeric_limits::min(); + if (frame_length_tensor) + { + frame_length = get_scalar_value_from_tensor(frame_length_tensor); + } + + // Get window length + int64_t window_length = std::numeric_limits::min(); + if (window) { + window_length = window->Shape()[0]; + } + + // The frame_length and window inputs are generally used interchangably, and should match! + if (frame_length != std::numeric_limits::min() && + window_length != std::numeric_limits::min()) { + ORT_ENFORCE(frame_length == window_length, "If both frame_length and window are set, then the size of the window must be equal to the frame_length."); + } + + // Calculate the window size with preference to the window input. + const auto window_size = window ? window->Shape()[0] : frame_length; + ORT_ENFORCE(window_size < signal_size, "Ensure that the dft size is smaller than the signal."); + + // Calculate the number of dfts to run + const auto n_dfts = static_cast(std::floor((signal_size - window_size) / static_cast(frame_step)) + 1); + + // Calculate the output spectra length (onesided will return only the unique values) + // note: x >> 1 === std::floor(x / 2.f) + const auto dft_output_size = + is_onesided ? + (window_size >> 1) + 1 : + window_size; + + // Get/create the output mutable data + auto output_spectra_shape = onnxruntime::TensorShape({batch_size, n_dfts, dft_output_size, 2}); + auto Y = ctx->Output(0, output_spectra_shape); + auto Y_data = reinterpret_cast(Y->MutableDataRaw()); + + // Get/create the signal mutable data + auto* signal_data = const_cast(reinterpret_cast(signal->DataRaw())); + + // Define tensor shapes for each dft run + const int64_t output_components = 2; + auto dft_input_shape = onnxruntime::TensorShape({1, window_size, signal_components}); + auto dft_output_shape = onnxruntime::TensorShape({1, dft_output_size, output_components}); + + std::vector> V; + std::vector> temp_output; + + // Run each dft of each batch as if it was a real-valued batch size 1 dft operation + for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { + for (int64_t i = 0; i < n_dfts; i++) { + auto input_frame_begin = + signal_data + + (batch_idx * signal_size * signal_components) + + (i * frame_step * signal_components); + + auto output_frame_begin = + Y_data + + (batch_idx * n_dfts * dft_output_size * output_components) + + (i * dft_output_size * output_components); + + // Tensors do not own the backing memory, so no worries on destruction + auto input = + onnxruntime::Tensor( + signal->DataType(), + dft_input_shape, + input_frame_begin, + signal->Location(), + 0); + + auto output = + onnxruntime::Tensor( + Y->DataType(), + dft_output_shape, + output_frame_begin, + Y->Location(), + 0); + + // Run individual dft + ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, &input, &output, window, is_onesided, false, V, temp_output))); + } + } + + return Status::OK(); +} + +Status STFT::Compute(OpKernelContext* ctx) const { + // Attr("onesided"): default = 1 + // Input(0, "signal") type = T1 + // Input(1, "frame_length") type = T2 + // Input(2, "window") type = T1, optional + // Input(3, "frame_step") type = T2 + // Output(0, "output") type = T1 + + // Get signal shape + const auto* signal = ctx->Input(0); + const auto& signal_shape = signal->Shape(); + const auto is_real_valued = is_real_valued_signal(signal_shape); + const auto is_complex_valued = is_complex_valued_signal(signal_shape); + + // Get data type + auto data_type = signal->DataType(); + + const auto element_size = data_type->Size(); + if (element_size == sizeof(float)) { + if (is_real_valued) { + ORT_RETURN_IF_ERROR((short_time_fourier_transform(ctx, is_onesided_, false))); + } else if (is_complex_valued) { + ORT_RETURN_IF_ERROR((short_time_fourier_transform>(ctx, is_onesided_, false))); + } else { + ORT_THROW("Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for complex inputs.", data_type); + } + } else if (element_size == sizeof(double)) { + if (is_real_valued) { + ORT_RETURN_IF_ERROR((short_time_fourier_transform(ctx, is_onesided_, false))); + } else if (is_complex_valued) { + ORT_RETURN_IF_ERROR((short_time_fourier_transform>(ctx, is_onesided_, false))); + } else { + ORT_THROW("Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for complex inputs.", data_type); + } + } else { + ORT_THROW("Unsupported input data type of ", data_type); + } + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/signal/dft.h b/onnxruntime/contrib_ops/cpu/signal/dft.h new file mode 100644 index 0000000000000..a3883bd4b4490 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/signal/dft.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef BUILD_MS_EXPERIMENTAL_OPS + +namespace onnxruntime { +namespace contrib { + +class DFT final : public OpKernel { + bool is_onesided_ = true; + public: + explicit DFT(const OpKernelInfo& info) : OpKernel(info) { + is_onesided_ = info.GetAttrOrDefault("onesided", 0); + } + Status Compute(OpKernelContext* ctx) const override; +}; + +class IDFT final : public OpKernel { + public: + explicit IDFT(const OpKernelInfo& info) : OpKernel(info) { + } + Status Compute(OpKernelContext* ctx) const override; +}; + +class STFT final : public OpKernel { + bool is_onesided_ = true; + public: + explicit STFT(const OpKernelInfo& info) : OpKernel(info) { + is_onesided_ = info.GetAttrOrDefault("onesided", 1); + } + Status Compute(OpKernelContext* ctx) const override; +}; + +} // namespace contrib +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/signal/window_functions.cc b/onnxruntime/contrib_ops/cpu/signal/window_functions.cc new file mode 100644 index 0000000000000..32057f198fabc --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/signal/window_functions.cc @@ -0,0 +1,333 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef BUILD_MS_EXPERIMENTAL_OPS + +#include "core/providers/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" +#include "Eigen/src/Core/Map.h" +#include "window_functions.h" +#include + +#include "core/platform/threadpool.h" + +#include + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + HannWindow, + kMSExperimentalDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", BuildKernelDefConstraints()), + HannWindow); + +ONNX_OPERATOR_KERNEL_EX( + HammingWindow, + kMSExperimentalDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", BuildKernelDefConstraints()), + HammingWindow); + +ONNX_OPERATOR_KERNEL_EX( + BlackmanWindow, + kMSExperimentalDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", BuildKernelDefConstraints()), + BlackmanWindow); + + +ONNX_OPERATOR_KERNEL_EX( + MelWeightMatrix, + kMSExperimentalDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", BuildKernelDefConstraints()) + .TypeConstraint("T3", BuildKernelDefConstraints()), + MelWeightMatrix); + + +template +static Status cosine_sum_window(Tensor* Y, size_t size, float a0, float a1, float a2) { + auto* Y_data = reinterpret_cast(Y->MutableDataRaw()); + + // Calculate the radians to increment per sample + constexpr double pi = 3.14159265; + constexpr double tau = 2 * pi; + const double angular_increment = tau / size; + + for (size_t i = 0; i < size; i++) { + auto a2_component = a2 == 0 ? 0 : (a2 * cos(2 * angular_increment * i)); + + T& value = *(Y_data + i); + value = static_cast(a0 - (a1 * cos(angular_increment * i)) + a2_component); + } + + return Status::OK(); +} + +template +static T get_scalar_value_from_tensor(const Tensor* tensor) { + ORT_ENFORCE(tensor->Shape().Size() == 1, "Tensor input should have a single value."); + auto data_type = tensor->DataType()->AsPrimitiveDataType()->GetDataType(); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + default: + ORT_THROW("Unsupported input data type of ", data_type); + } +} + +static Status create_cosine_sum_window( + OpKernelContext* ctx, + onnx::TensorProto_DataType output_datatype, + float a0, float a1, float a2) { + + // Get the size of the window + auto size = get_scalar_value_from_tensor(ctx->Input(0)); + + // Get the output tensor + auto Y_shape = onnxruntime::TensorShape({size}); + auto Y = ctx->Output(0, Y_shape); + + switch (output_datatype) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { + ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); + break; + } + default: + ORT_THROW("Unsupported input data type of ", output_datatype); + } + + return Status::OK(); +} + +Status HannWindow::Compute(OpKernelContext* ctx) const { + // HannWindows are a special case of Cosine-Sum Windows which take the following form: + // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: + float a0 = .5f; + float a1 = a0; + float a2 = 0; + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); +} + +Status HammingWindow::Compute(OpKernelContext* ctx) const { + // HammingWindows are a special case of Cosine-Sum Windows which take the following form: + // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: + float a0 = 25.f / 46.f; + float a1 = 1 - a0; + float a2 = 0; + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); +} + +Status BlackmanWindow::Compute(OpKernelContext* ctx) const { + // BlackmanWindows are a special case of Cosine-Sum Windows which take the following form: + // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: + float alpha = .16f; + float a2 = alpha / 2.f; + float a0 = .5f - a2; + float a1 = .5f; + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); +} + +static inline double hz_to_mel_scale(double hz) { + return 2595 * std::log10(1 + hz / 700); +} + +static inline double mel_scale_to_hz(double mels) { + return 700 * (pow(10, (mels / 2595)) - 1); +} + +template +Status create_mel_weight_matrix(OpKernelContext* ctx, int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, float lower_edge_hertz, float upper_edge_hertz) { + // Determine the width of the spectrogram. + // This is determined as half the size of the fft size. The first element of the spectrum is always retained, + // and the remaining are halved. The second half can be discarded due to the conjugate symmetry of the output with real valued ffts. + // Taken together the formula for the size of the output will be std::floor(dft_length / 2) + 1. + int64_t num_spectrogram_bins = static_cast(std::floor(dft_length / 2 + 1)); + + // Checks + auto lowest_index = std::floor(((dft_length + 1) * lower_edge_hertz) / sample_rate); + auto highest_index = std::floor(((dft_length + 1) * upper_edge_hertz) / sample_rate); + ORT_ENFORCE(lowest_index >= 0 && lowest_index < num_spectrogram_bins, "lower_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the sample_rate."); + ORT_ENFORCE(highest_index >= 0 && highest_index < num_spectrogram_bins, "upper_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the sample_rate."); + + // Create the output shape + onnxruntime::TensorShape output_shape( + { + static_cast(num_spectrogram_bins), + num_mel_bins + }); + auto* Y = ctx->Output(0, output_shape); + + // Get the raw output data + auto* Y_data = reinterpret_cast(Y->MutableDataRaw()); + + // Set the weight matrix to 0 + memset(Y_data, 0, num_spectrogram_bins * num_mel_bins * sizeof(T)); + + // The mel filterbank is a triangular shaped peak with a height of 1 and a base equal to the size of the MEL range divided by + // the number of bins needed times 2. This triagle is then slid across the mel domain linearly, with a constant step size that + // is equal to half of the base of the triange. To accomodate N bins, N+2 data points will be needed to determine the + // start, center and end points of each mel triange filter. + // + // low_frequency where the mel triangle filter banks begin, and they end on the high_frequency_mel + // The range is divided evenly to create the needed points corresponding to the begin, center, end points of each triangle filterbank + std::vector frequency_bins(num_mel_bins + 2); + auto low_frequency_mel = hz_to_mel_scale(lower_edge_hertz); + auto high_frequency_mel = hz_to_mel_scale(upper_edge_hertz); + auto mel_step = (high_frequency_mel - low_frequency_mel) / static_cast(frequency_bins.size()); + + // Convert each point from mel scale back to hertz, and then compute the corresponding index in the fft + for (size_t i = 0; i < frequency_bins.size(); i++) { + auto hz = mel_scale_to_hz(low_frequency_mel + mel_step * i); + frequency_bins[i] = static_cast(std::floor(((dft_length + 1) * hz) / sample_rate)); + } + + for (size_t i = 0; i < static_cast(num_mel_bins); i++) { + auto lower_frequency_value = frequency_bins[i]; //left + auto center_frequency_point = frequency_bins[i+1]; //center + auto higher_frequency_point = frequency_bins[i+2]; //right + + auto low_to_center = center_frequency_point - lower_frequency_value; + if (low_to_center == 0) { + auto& current_element = *(Y_data + (center_frequency_point * num_mel_bins) + i); + current_element = static_cast(1); + } else { + for (size_t j = lower_frequency_value; j <= center_frequency_point; j++) { + auto& current_element = *(Y_data + (j * num_mel_bins) + i); + current_element = static_cast((j - lower_frequency_value) / static_cast(low_to_center)); + } + } + + auto center_to_high = higher_frequency_point - center_frequency_point; + if (center_to_high > 0) { + for (size_t j = center_frequency_point; j < higher_frequency_point; j++) { + auto& current_element = *(Y_data + (j * num_mel_bins) + i); + current_element = static_cast((higher_frequency_point - j) / static_cast(center_to_high)); + } + } + } + + return Status::OK(); +} + +static Status create_mel_weight_matrix(OpKernelContext* ctx, onnx::TensorProto_DataType output_datatype, + int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, float lower_edge_hertz, float upper_edge_hertz) { + switch (output_datatype) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { + ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); + break; + } + default: + ORT_THROW("Unsupported input data type of ", output_datatype); + } + return Status::OK(); +} + +Status MelWeightMatrix::Compute(OpKernelContext* ctx) const { + const auto num_mel_bins = get_scalar_value_from_tensor(ctx->Input(0)); + const auto dft_length = get_scalar_value_from_tensor(ctx->Input(1)); + const auto sample_rate = get_scalar_value_from_tensor(ctx->Input(2)); + const auto lower_edge_hertz = get_scalar_value_from_tensor(ctx->Input(3)); + const auto upper_edge_hertz = get_scalar_value_from_tensor(ctx->Input(4)); + + ORT_RETURN_IF_ERROR(create_mel_weight_matrix(ctx, data_type_, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)); + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/signal/window_functions.h b/onnxruntime/contrib_ops/cpu/signal/window_functions.h new file mode 100644 index 0000000000000..81d8d3b48c656 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/signal/window_functions.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef BUILD_MS_EXPERIMENTAL_OPS + +namespace onnxruntime { +namespace contrib { + +class VariableOutputDataTypeBase : public OpKernel { + protected: + onnx::TensorProto_DataType data_type_; + + public: + VariableOutputDataTypeBase(const OpKernelInfo& info) : OpKernel(info) { + data_type_ = static_cast(info.GetAttrOrDefault("output_datatype", onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)); + } +}; + +class HannWindow final : public VariableOutputDataTypeBase { + public: + explicit HannWindow(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) { + } + Status Compute(OpKernelContext* ctx) const override; +}; + +class HammingWindow final : public VariableOutputDataTypeBase { + public: + explicit HammingWindow(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) { + } + Status Compute(OpKernelContext* ctx) const override; +}; + +class BlackmanWindow final : public VariableOutputDataTypeBase { + public: + explicit BlackmanWindow(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) { + } + Status Compute(OpKernelContext* ctx) const override; +}; + +class MelWeightMatrix final : public VariableOutputDataTypeBase { + public: + explicit MelWeightMatrix(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) { + } + Status Compute(OpKernelContext* ctx) const override; +}; + +} // namespace contrib +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 145f6fbab5161..7d130ccb36f76 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -348,4 +348,4 @@ OrtStatus* OrtTypeInfo::Clone(OrtTypeInfo** out) { break; } return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); -} \ No newline at end of file +} diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 3c256aa73d17d..5b9145d32e28c 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -43,7 +43,6 @@ struct OrtTypeInfo { static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out); static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type); - private: OrtTypeInfo(ONNXType type) noexcept; OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept; OrtTypeInfo(ONNXType type, OrtMapTypeInfo* map_type_info) noexcept; diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index c20945a767e0a..d00f37f3cb91a 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -251,4 +251,4 @@ ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, _Outptr_result auto status = OrtTypeInfo::FromOrtValue(*v, out); return status; API_IMPL_END -} +} \ No newline at end of file diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 54bcbd2b091cc..fdda10f467a5e 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -11,6 +11,7 @@ #include "onnx/defs/shape_inference.h" #include "onnx/defs/tensor_proto_util.h" #include "core/mlas/inc/mlas.h" +#include "core/graph/signal_ops/signal_defs.h" namespace ONNX_NAMESPACE { void convPoolShapeInference( @@ -1122,7 +1123,7 @@ Sample echo operator.)DOC"); .SinceVersion(1) .SetDoc(R"DOC()DOC") .Input(0, "X", "input tensor", "T") - .Attr("signal_ndim", "", AttributeProto::INT) + .Attr("signal_ndim", "", AttributeProto::INT, static_cast(1)) .Attr("normalized", "", AttributeProto::INT, static_cast(0)) .Attr("onesided", "", AttributeProto::INT, static_cast(1)) .Output(0, "Y", "output tensor", "T") @@ -2372,6 +2373,11 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i RegisterNchwcSchemas(); } RegisterBertSchemas(); + +#ifdef BUILD_MS_EXPERIMENTAL_OPS + onnxruntime::signal::RegisterSignalSchemas(); +#endif + RegisterQuantizationSchemas(); } } // namespace contrib diff --git a/onnxruntime/core/graph/signal_ops/signal_defs.cc b/onnxruntime/core/graph/signal_ops/signal_defs.cc new file mode 100644 index 0000000000000..382fe8b502da6 --- /dev/null +++ b/onnxruntime/core/graph/signal_ops/signal_defs.cc @@ -0,0 +1,324 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef BUILD_MS_EXPERIMENTAL_OPS + +#include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" +#include "core/graph/constants.h" +#include "core/graph/signal_ops/signal_defs.h" +#include "core/graph/op.h" +#include "onnx/defs/schema.h" +#include "onnx/defs/shape_inference.h" +#include "onnx/defs/tensor_proto_util.h" + + +namespace onnxruntime { +namespace signal { + +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::OPTIONAL_VALUE; + +template +static T get_scalar_value_from_tensor(const ONNX_NAMESPACE::TensorProto* t) { + if (t == nullptr) { + return T{}; + } + + auto data_type = t->data_type(); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto::FLOAT: + return static_cast(ONNX_NAMESPACE::ParseData(t).at(0)); + case ONNX_NAMESPACE::TensorProto::DOUBLE: + return static_cast(ONNX_NAMESPACE::ParseData(t).at(0)); + case ONNX_NAMESPACE::TensorProto::INT32: + return static_cast(ONNX_NAMESPACE::ParseData(t).at(0)); + case ONNX_NAMESPACE::TensorProto::INT64: + return static_cast(ONNX_NAMESPACE::ParseData(t).at(0)); + default: + ORT_THROW("Unsupported input data type of ", data_type); + } +} + +void RegisterSignalSchemas() { + MS_SIGNAL_OPERATOR_SCHEMA(DFT) + .SetDomain(kMSExperimentalDomain) + .SinceVersion(1) + .SetDoc(R"DOC(DFT)DOC") + .Attr("onesided", + "If True (default), only values for half of the fft size are returned because the real-to-complex Fourier transform satisfies the conjugate symmetry." + "The output tensor will return the first floor(n_fft/2) + 1 values from the DFT." + "Values can be 0 or 1.", + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, + static_cast(0)) + .Input(0, + "input", + "For complex input, the following shape is expected: [batch_idx][n_fft][2]" + "The final dimension represents the real and imaginary parts of the value." + "For real input, the following shape is expected: [batch_idx][n_fft]" + "The first dimension is the batch dimension.", + "T") + .Output(0, + "output", + "The Fourier Transform of the input vector." + "If onesided is 1, [batch_idx][floor(n_fft/2)+1][2]" + "If onesided is 0, [batch_idx][n_fft][2]", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + int64_t ndim = 1; + + bool is_onesided = true; + auto attr_proto = ctx.getAttribute("onesided"); + if (attr_proto && attr_proto->has_i()) { + is_onesided = static_cast(attr_proto->i()); + } + + if (ctx.getInputType(0)->tensor_type().has_shape()) { + auto& input_shape = getInputShape(ctx, 0); + ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape; + + if (is_onesided) { + auto n_fft = input_shape.dim(1).dim_value(); + result_shape.mutable_dim(1)->set_dim_value((n_fft >> 1) + 1); + } + + auto dim_size = static_cast(input_shape.dim_size()); + if (dim_size == ndim + 1) { // real input + result_shape.add_dim()->set_dim_value(2); // output is same shape, but with extra dim for 2 values (real/imaginary) + } else if (dim_size == ndim + 2) { // complex input, do nothing + } else { + fail_shape_inference( + "the input_shape must [batch_idx][n_fft] for real values or [batch_idx][n_fft][2] for complex values.") + } + updateOutputShape(ctx, 0, result_shape); + } + }); + ; + + MS_SIGNAL_OPERATOR_SCHEMA(IDFT) + .SetDomain(kMSExperimentalDomain) + .SinceVersion(1) + .SetDoc(R"DOC(IDFT)DOC") + .Input(0, + "input", + "A complex signal of dimension signal_ndim." + "The last dimension of the tensor should be 2," + "representing the real and imaginary components of complex numbers," + "and should have at least signal_ndim + 2 dimensions." + "The first dimension is the batch dimension.", + "T") + .Output(0, + "output", + "The inverse fourier transform of the input vector," + "using the same format as the input.", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + int64_t ndim = 1; + auto attr_proto = ctx.getAttribute("signal_ndim"); + if (attr_proto && attr_proto->has_i()) { + ndim = static_cast(attr_proto->i()); + } + + auto& input_shape = getInputShape(ctx, 0); + ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape; + + auto dim_size = static_cast(input_shape.dim_size()); + if (dim_size == ndim + 1) { // real input + result_shape.add_dim()->set_dim_value(2); // output is same shape, but with extra dim for 2 values (real/imaginary) + } else if (dim_size == ndim + 2) { // complex input, do nothing + } else { + fail_shape_inference( + "the input_shape must have 1 + signal_ndim dimensions for real inputs, or 2 + signal_ndim dimensions for complex input.") + } + + updateOutputShape(ctx, 0, result_shape); + }); + + MS_SIGNAL_OPERATOR_SCHEMA(STFT) + .SetDomain(kMSExperimentalDomain) + .SinceVersion(1) + .SetDoc(R"DOC(STFT)DOC") + .Attr("onesided", + "If True (default), only values for half of the fft size are returned because the real-to-complex Fourier transform satisfies the conjugate symmetry." + "The output tensor will return the first floor(n_fft/2) + 1 values from the DFT." + "Values can be 0 or 1.", + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, + static_cast(1)) + .Input(0, + "signal", + "A complex signal of dimension signal_ndim." + "The last dimension of the tensor should be 2," + "representing the real and imaginary components of complex numbers," + "and should have at least signal_ndim + 2 dimensions." + "The first dimension is the batch dimension.", + "T1") + .Input(1, + "window", + "A tensor representing the window that will be slid over the input signal.", + "T1", + OpSchema::FormalParameterOption::Optional) + .Input(2, + "frame_length", // frame_length, fft_length, pad_mode + "Size of the fft.", + "T2", + OpSchema::FormalParameterOption::Optional) + .Input(3, + "frame_step", + "The number of samples to step between successive DFTs.", + "T2") + .Output(0, + "output", + "The inverse fourier transform of the input vector," + "using the same format as the input.", + "T1") + .TypeConstraint("T1", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "") + .TypeConstraint("T2", {"tensor(int64)"}, ""); + + // Window Functions + MS_SIGNAL_OPERATOR_SCHEMA(HannWindow) + .SetDomain(kMSExperimentalDomain) + .SinceVersion(1) + .SetDoc(R"DOC(HannWindow)DOC") + .Attr("output_datatype", + "The data type of the output tensor. " + "Strictly must be one of the types from DataType enum in TensorProto.", + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, + static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)) + .Input(0, + "size", + "A scalar value indicating the length of the Hann Window.", + "T1") + .Output(0, + "output", + "A Hann Window with length: size.", + "T2") + .TypeConstraint("T1", {"tensor(int64)"}, "") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(double)", "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)"}, "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto size = get_scalar_value_from_tensor(ctx.getInputData(0)); + if (size > 0) { + ONNX_NAMESPACE::TensorShapeProto result_shape; + result_shape.add_dim()->set_dim_value(size); + updateOutputShape(ctx, 0, result_shape); + } + + propagateElemTypeFromAttributeToOutput(ctx, "output_datatype", 0); + }); + + MS_SIGNAL_OPERATOR_SCHEMA(HammingWindow) + .SetDomain(kMSExperimentalDomain) + .SinceVersion(1) + .SetDoc(R"DOC(HammingWindow)DOC") + .Attr("output_datatype", + "The data type of the output tensor. " + "Strictly must be one of the types from DataType enum in TensorProto.", + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, + static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)) + .Input(0, + "size", + "A scalar value indicating the length of the Hamming Window.", + "T1") + .Output(0, + "output", + "A Hamming Window with length: size.", + "T2") + .TypeConstraint("T1", {"tensor(int64)"}, "") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(double)", "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)"}, "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto size = get_scalar_value_from_tensor(ctx.getInputData(0)); + if (size > 0) { + ONNX_NAMESPACE::TensorShapeProto result_shape; + result_shape.add_dim()->set_dim_value(size); + updateOutputShape(ctx, 0, result_shape); + } + propagateElemTypeFromAttributeToOutput(ctx, "output_datatype", 0); + }); + + MS_SIGNAL_OPERATOR_SCHEMA(BlackmanWindow) + .SetDomain(kMSExperimentalDomain) + .SinceVersion(1) + .SetDoc(R"DOC(BlackmanWindow)DOC") + .Attr("output_datatype", + "The data type of the output tensor. " + "Strictly must be one of the types from DataType enum in TensorProto.", + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, + static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)) + .Input(0, + "size", + "A scalar value indicating the length of the Blackman Window.", + "T1") + .Output(0, + "output", + "A Blackman Window with length: size.", + "T2") + .TypeConstraint("T1", {"tensor(int64)"}, "") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(double)", "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)"}, "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto size = get_scalar_value_from_tensor(ctx.getInputData(0)); + if (size > 0) { + ONNX_NAMESPACE::TensorShapeProto result_shape; + result_shape.add_dim()->set_dim_value(size); + updateOutputShape(ctx, 0, result_shape); + } + propagateElemTypeFromAttributeToOutput(ctx, "output_datatype", 0); + }); + + MS_SIGNAL_OPERATOR_SCHEMA(MelWeightMatrix) + .SetDomain(kMSExperimentalDomain) + .SinceVersion(1) + .SetDoc(R"DOC(MelWeightMatrix)DOC") + .Attr("output_datatype", + "The data type of the output tensor. " + "Strictly must be one of the types from DataType enum in TensorProto.", + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, + static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)) + .Input(0, + "num_mel_bins", + "The number of bands in the mel spectrum.", + "T1") + .Input(1, + "dft_length", + "The size of the FFT.", + "T1") + .Input(2, + "sample_rate", + "", + "T1") + .Input(3, + "lower_edge_hertz", + "", + "T2") + .Input(4, + "upper_edge_hertz", + "", + "T2") + .Output(0, + "output", + "The MEL Matrix", + "T3") + .TypeConstraint("T1", {"tensor(int64)"}, "") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(double)"}, "") + .TypeConstraint("T3", {"tensor(float)", "tensor(float16)", "tensor(double)", "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)"}, "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto num_mel_bins = get_scalar_value_from_tensor(ctx.getInputData(0)); + auto dft_length = get_scalar_value_from_tensor(ctx.getInputData(1)); + if (num_mel_bins > 0 && dft_length > 0) { + ONNX_NAMESPACE::TensorShapeProto result_shape; + // Figure out how to specify one-sided??? + result_shape.add_dim()->set_dim_value(static_cast(std::floor(dft_length / 2.f + 1))); + result_shape.add_dim()->set_dim_value(num_mel_bins); + updateOutputShape(ctx, 0, result_shape); + } + propagateElemTypeFromAttributeToOutput(ctx, "output_datatype", 0); + }); +} + +} // namespace audio +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/onnxruntime/core/graph/signal_ops/signal_defs.h b/onnxruntime/core/graph/signal_ops/signal_defs.h new file mode 100644 index 0000000000000..503b6b8ff56f6 --- /dev/null +++ b/onnxruntime/core/graph/signal_ops/signal_defs.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/onnx_protobuf.h" + +namespace onnxruntime { +namespace signal { +#define MS_SIGNAL_OPERATOR_SCHEMA(name) \ + MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name) +#define MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) \ + MS_SIGNAL_OPERATOR_SCHEMA_UNIQ(Counter, name) +#define MS_SIGNAL_OPERATOR_SCHEMA_UNIQ(Counter, name) \ + static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ + op_schema_register_once##name##Counter) ONNX_UNUSED = \ + ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__) + +#define MS_SIGNAL_OPERATOR_SCHEMA_ELSEWHERE(name, schema_func) \ + MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(__COUNTER__, name, schema_func) +#define MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(Counter, name, schema_func) \ + MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) +#define MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) \ + static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ + op_schema_register_once##name##Counter) ONNX_UNUSED = \ + schema_func(ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__)) + +void RegisterSignalSchemas(); +} // namespace dml +} // namespace onnxruntime diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 1007208fc7ece..25606dd644ec8 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -154,6 +154,7 @@ Status Environment::Initialize(std::unique_ptr logging_ // Register Microsoft domain with min/max op_set version as 1/1. std::call_once(schemaRegistrationOnceFlag, []() { ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSDomain, 1, 1); + ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSExperimentalDomain, 1, 1); ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1); ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSFeaturizersDomain, 1, 1); #ifdef USE_DML diff --git a/onnxruntime/test/contrib_ops/sample_op_test.cc b/onnxruntime/test/contrib_ops/sample_op_test.cc index 53d47acc99cc6..bb1bdbd1da43d 100644 --- a/onnxruntime/test/contrib_ops/sample_op_test.cc +++ b/onnxruntime/test/contrib_ops/sample_op_test.cc @@ -18,4 +18,4 @@ TEST(MLOpTest, SampleOpFloat) { } } // namespace test -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/contrib_ops/signal_ops_test.cc b/onnxruntime/test/contrib_ops/signal_ops_test.cc new file mode 100644 index 0000000000000..3fe4ce75e604e --- /dev/null +++ b/onnxruntime/test/contrib_ops/signal_ops_test.cc @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef BUILD_MS_EXPERIMENTAL_OPS + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void TestNaiveDFTFloat(bool is_onesided) { + OpTester test("DFT", 1, onnxruntime::kMSExperimentalDomain); + + std::vector shape = {1, 5}; + std::vector output_shape = {1, 5, 2}; + output_shape[1] = is_onesided ? (1 + (shape[1] >> 1)) : shape[1]; + + std::vector input = {1, 2, 3, 4, 5}; + std::vector expected_output = { + 15.000000f, 0.0000000f, + -2.499999f, 3.4409550f, + -2.500000f, 0.8123000f, + -2.499999f, -0.812299f, + -2.500003f, -3.440953f + }; + + if (is_onesided) { + expected_output.resize(6); + } + test.AddInput("input", shape, input); + test.AddAttribute("onesided", static_cast(is_onesided)); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +static void TestRadix2DFTFloat(bool is_onesided) { + OpTester test("DFT", 1, onnxruntime::kMSExperimentalDomain); + + std::vector shape = {1, 8}; + std::vector output_shape = {1, 8, 2}; + output_shape[1] = is_onesided ? (1 + (shape[1] >> 1)) : shape[1]; + + std::vector input = {1, 2, 3, 4, 5, 6, 7, 8}; + std::vector expected_output = { + 36.000f, 0.000f, + -4.000f, 9.65685f, + -4.000f, 4.000f, + -4.000f, 1.65685f, + -4.000f, 0.000f, + -4.000f, -1.65685f, + -4.000f, -4.000f, + -4.000f, -9.65685f + }; + + if (is_onesided) { + expected_output.resize(10); + } + test.AddInput("input", shape, input); + test.AddAttribute("onesided", static_cast(is_onesided)); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(MLSignalOpTest, DFTFloat) { + TestNaiveDFTFloat(false); + TestNaiveDFTFloat(true); + TestRadix2DFTFloat(false); + TestRadix2DFTFloat(true); +} + +TEST(MLSignalOpTest, IDFTFloat) { + OpTester test("IDFT", 1, onnxruntime::kMSExperimentalDomain); + + std::vector shape = {1, 5, 2}; + std::vector input = + { + 15.000000f, 0.0000000f, + -2.499999f, 3.4409550f, + -2.500000f, 0.8123000f, + -2.499999f, -0.812299f, + -2.500003f, -3.440953f + }; + std::vector expected_output = + { + 1.000f, 0.000f, + 2.000f, 0.000f, + 3.000f, 0.000f, + 4.000f, 0.000f, + 5.000f, 0.000f + }; + + test.AddInput("input", shape, input); + test.AddOutput("output", shape, expected_output); + test.Run(); +} + +TEST(MLSignalOpTest, STFTFloat) { + OpTester test("STFT", 1, onnxruntime::kMSExperimentalDomain); + + std::vector signal(64, 1); + test.AddInput("signal", {1, 64}, signal); + std::vector window(16, 1); + test.AddInput("window", {16}, window); + test.AddInput("frame_length", {}, {16}); + test.AddInput("frame_step", {}, {8}); + + std::vector output_shape = {1, 7, 9, 2}; + std::vector expected_output = + { + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f + }; + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(MLSignalOpTest, HannWindowFloat) { + OpTester test("HannWindow", 1, onnxruntime::kMSExperimentalDomain); + + std::vector scalar_shape = {}; + std::vector output_shape = {32}; + std::vector expected_output = + { + 0.000000f, 0.009607f, 0.038060f, 0.084265f, 0.146447f, 0.222215f, 0.308658f, 0.402455f, + 0.500000f, 0.597545f, 0.691342f, 0.777785f, 0.853553f, 0.915735f, 0.961940f, 0.990393f, + 1.000000f, 0.990393f, 0.961940f, 0.915735f, 0.853553f, 0.777785f, 0.691342f, 0.597545f, + 0.500000f, 0.402455f, 0.308658f, 0.222215f, 0.146447f, 0.084265f, 0.038060f, 0.009607f + }; + + test.AddInput("size", scalar_shape, {32}); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(MLSignalOpTest, HammingWindowFloat) { + OpTester test("HammingWindow", 1, onnxruntime::kMSExperimentalDomain); + + std::vector scalar_shape = {}; + std::vector output_shape = {32}; + std::vector expected_output = + { + 0.086957f, 0.095728f, 0.121707f, 0.163894f, 0.220669f, 0.289848f, 0.368775f, 0.454415f, + 0.543478f, 0.632541f, 0.718182f, 0.797108f, 0.866288f, 0.923062f, 0.965249f, 0.991228f, + 1.000000f, 0.991228f, 0.965249f, 0.923062f, 0.866288f, 0.797108f, 0.718182f, 0.632541f, + 0.543478f, 0.454415f, 0.368775f, 0.289848f, 0.220669f, 0.163894f, 0.121707f, 0.095728f + }; + + test.AddInput("size", scalar_shape, {32}); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(MLSignalOpTest, BlackmanWindowFloat) { + OpTester test("BlackmanWindow", 1, onnxruntime::kMSExperimentalDomain); + + std::vector scalar_shape = {}; + std::vector output_shape = {32}; + std::vector expected_output = + { + 0.000000f, 0.003518f, 0.014629f, 0.034880f, 0.066447f, 0.111600f, 0.172090f, 0.248544f, + 0.340000f, 0.443635f, 0.554773f, 0.667170f, 0.773553f, 0.866350f, 0.938508f, 0.984303f, + 1.000000f, 0.984303f, 0.938508f, 0.866350f, 0.773553f, 0.667170f, 0.554773f, 0.443635f, + 0.340000f, 0.248544f, 0.172090f, 0.111600f, 0.066447f, 0.034880f, 0.014629f, 0.003518f + }; + + test.AddInput("size", scalar_shape, {32}); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(MLSignalOpTest, MelWeightMatrixFloat) { + OpTester test("MelWeightMatrix", 1, onnxruntime::kMSExperimentalDomain); + + std::vector scalar_shape = {}; + std::vector output_shape = {9, 8}; + std::vector expected_output = + { + 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f + }; + + test.AddInput("num_mel_bins", scalar_shape, {8}); + test.AddInput("dft_length", scalar_shape, {16}); + test.AddInput("sample_rate", scalar_shape, {8192}); + test.AddInput("lower_edge_hertz", scalar_shape, {0}); + test.AddInput("upper_edge_hertz", scalar_shape, {8192 / 2.f}); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index c985cda52014c..3fccffae59c6c 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -471,7 +471,8 @@ def parse_arguments(): # Code coverage parser.add_argument("--code_coverage", action='store_true', help="Generate code coverage when targetting Android (only).") - + parser.add_argument( + "--ms_experimental", action='store_true', help="Build microsoft experimental operators.") return parser.parse_args() @@ -690,6 +691,7 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "ON" if args.enable_language_interop_ops else "OFF"), "-Donnxruntime_USE_DML=" + ("ON" if args.use_dml else "OFF"), "-Donnxruntime_USE_WINML=" + ("ON" if args.use_winml else "OFF"), + "-Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=" + ("ON" if args.ms_experimental else "OFF"), "-Donnxruntime_USE_TELEMETRY=" + ( "ON" if args.use_telemetry else "OFF"), "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), diff --git a/tools/ci_build/github/azure-pipelines/templates/windowsai-nuget-build.yml b/tools/ci_build/github/azure-pipelines/templates/windowsai-nuget-build.yml index e6e7cc3727e86..7da626b62d2e2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/windowsai-nuget-build.yml +++ b/tools/ci_build/github/azure-pipelines/templates/windowsai-nuget-build.yml @@ -70,7 +70,7 @@ steps: displayName: 'Generate CMake Configuration' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --enable_onnx_tests $(TelemetryOption) --use_winml --cmake_generator "Visual Studio 16 2019" --update --config RelWithDebInfo --enable_lto --disable_rtti $(BuildFlags)' + arguments: '--build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --enable_onnx_tests $(TelemetryOption) --ms_experimental --use_winml --cmake_generator "Visual Studio 16 2019" --update --config RelWithDebInfo --enable_lto --disable_rtti $(BuildFlags)' workingDirectory: '$(Build.BinariesDirectory)' - ${{ if or(notIn(parameters['sln_platform'], 'Win32', 'x64'), eq(parameters.BuildForStore, 'true')) }}: diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index b3e352937fbb9..9e0e7375c52eb 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -278,6 +278,10 @@ def generate_files(list, args): files_list.append('') + # Process microsoft.ai.machinelearning.experimental.winmd + files_list.append('') if args.target_architecture == 'x64' and not args.is_store_build: interop_dll_path = 'Microsoft.AI.MachineLearning.Interop\\net5.0-windows10.0.19041.0' interop_dll = interop_dll_path + '\\Microsoft.AI.MachineLearning.Interop.dll' diff --git a/winml/adapter/winml_adapter_apis.h b/winml/adapter/winml_adapter_apis.h index 66114d2c5e499..66487b9525c55 100644 --- a/winml/adapter/winml_adapter_apis.h +++ b/winml/adapter/winml_adapter_apis.h @@ -38,6 +38,7 @@ ORT_API_STATUS(ModelGetOutputTypeInfo, _In_ const OrtModel* model, _In_ size_t i ORT_API_STATUS(ModelGetMetadataCount, _In_ const OrtModel* model, _Out_ size_t* count); ORT_API_STATUS(ModelGetMetadata, _In_ const OrtModel* model, _In_ size_t count, _Out_ const char** const key, _Out_ size_t* key_len, _Out_ const char** const value, _Out_ size_t* value_len); ORT_API_STATUS(ModelEnsureNoFloat16, _In_ const OrtModel* model); +ORT_API_STATUS(SaveModel, _In_ const OrtModel* in, _In_ const wchar_t* const file_name, _In_ size_t len); ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, _In_ ID3D12Device* d3d_device, _In_ ID3D12CommandQueue* cmd_queue, bool metacommands_enabled); @@ -80,6 +81,55 @@ ORT_API_STATUS(CreateCustomRegistry, _Out_ IMLOperatorRegistry** registry); ORT_API_STATUS(ValueGetDeviceId, _In_ OrtValue* ort_value, _Out_ int16_t* device_id); ORT_API_STATUS(SessionGetInputRequiredDeviceId, _In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id); +// Model Building +ORT_API_STATUS(CreateTensorTypeInfo, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS(CreateSequenceTypeInfo, _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS(CreateMapTypeInfo, _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS(CreateModel, _In_ int64_t opset, _Outptr_ OrtModel** out); +ORT_API_STATUS(ModelAddInput, _In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info); +ORT_API_STATUS(ModelAddConstantInput, _In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info, _In_ OrtValue* value); +ORT_API_STATUS(ModelAddOutput, _In_ OrtModel* model, _In_ const char* const output_name, _In_ OrtTypeInfo* info); +ORT_API_STATUS(ModelAddOperator, + _In_ OrtModel* model, + _In_ const char* const op_type, + _In_ const char* const op_name, + _In_ int64_t opset, + _In_ const char* const op_domain, + _In_ const char* const* input_names, _In_ size_t num_inputs, + _In_ const char* const* output_names, _In_ size_t num_outputs, + _In_ const char* const* attribute_names, _In_ OrtValue** attribute_values, _In_ size_t num_attributes); + +ORT_API_STATUS(ModelGetOpsetVersion, _In_ OrtModel* model, _In_ const char* const domain, _Out_ int32_t* version); + +ORT_API_STATUS(OperatorGetNumInputs, + _In_ const char* const op_type, + _In_ int64_t opset, + _In_ const char* const op_domain, + _Out_ size_t* num_inputs); + +ORT_API_STATUS(OperatorGetInputName, + _In_ const char* const op_type, + _In_ int64_t opset, + _In_ const char* const op_domain, + _In_ size_t index, + _Out_ const char** const name); + +ORT_API_STATUS(OperatorGetNumOutputs, + _In_ const char* const op_type, + _In_ int64_t opset, + _In_ const char* const op_domain, + _Out_ size_t* num_inputs); + +ORT_API_STATUS(OperatorGetOutputName, + _In_ const char* const op_type, + _In_ int64_t opset, + _In_ const char* const op_domain, + _In_ size_t index, + _Out_ const char** const name); + +// maps and sequences??? +//ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange().Map().at(ONNX_NAMESPACE::ONNX_DOMAIN).second + } // namespace Adapter } // namespace MachineLearning } // namespace AI diff --git a/winml/adapter/winml_adapter_c_api.cpp b/winml/adapter/winml_adapter_c_api.cpp index c4f15ff401e26..5abf8e705dc61 100644 --- a/winml/adapter/winml_adapter_c_api.cpp +++ b/winml/adapter/winml_adapter_c_api.cpp @@ -40,6 +40,7 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = { &winmla::ModelGetMetadataCount, &winmla::ModelGetMetadata, &winmla::ModelEnsureNoFloat16, + &winmla::SaveModel, // OrtSessionOptions methods &OrtSessionOptionsAppendExecutionProvider_CPU, @@ -79,6 +80,20 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = { &winmla::ValueGetDeviceId, &winmla::SessionGetInputRequiredDeviceId, + &winmla::CreateTensorTypeInfo, + &winmla::CreateSequenceTypeInfo, + &winmla::CreateMapTypeInfo, + &winmla::CreateModel, + &winmla::ModelAddInput, + &winmla::ModelAddConstantInput, + &winmla::ModelAddOutput, + &winmla::ModelAddOperator, + &winmla::ModelGetOpsetVersion, + &winmla::OperatorGetNumInputs, + &winmla::OperatorGetInputName, + &winmla::OperatorGetNumOutputs, + &winmla::OperatorGetOutputName, + // Release &winmla::ReleaseModel }; diff --git a/winml/adapter/winml_adapter_c_api.h b/winml/adapter/winml_adapter_c_api.h index 75f7acef723f4..4a2b9ed26c493 100644 --- a/winml/adapter/winml_adapter_c_api.h +++ b/winml/adapter/winml_adapter_c_api.h @@ -202,6 +202,12 @@ struct WinmlAdapterApi { */ OrtStatus*(ORT_API_CALL* ModelEnsureNoFloat16)(_In_ const OrtModel* model)NO_EXCEPTION; + /** + * SaveModel + * This api save the model to the fiven file + */ + OrtStatus*(ORT_API_CALL* SaveModel)(_In_ const OrtModel* in, _In_ const wchar_t* const file_name, _In_ size_t len)NO_EXCEPTION; + // OrtSessionOptions methods /** @@ -425,5 +431,53 @@ struct WinmlAdapterApi { */ OrtStatus*(ORT_API_CALL* SessionGetInputRequiredDeviceId)(_In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* CreateTensorTypeInfo)(_In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, _Out_ OrtTypeInfo** type_info)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* CreateSequenceTypeInfo)(_Out_ OrtTypeInfo** type_info)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* CreateMapTypeInfo)(_Out_ OrtTypeInfo** type_info)NO_EXCEPTION; + + OrtStatus*(ORT_API_CALL* CreateModel)(_In_ int64_t opset, _Outptr_ OrtModel** out)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* ModelAddInput)(_In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* ModelAddConstantInput)(_In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info, _In_ OrtValue* value)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* ModelAddOutput)(_In_ OrtModel* model, _In_ const char* const output_name, _In_ OrtTypeInfo* info)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* ModelAddOperator)( + _In_ OrtModel* model, + _In_ const char* const op_type, + _In_ const char* const op_name, + _In_ int64_t opset, + _In_ const char* const op_domain, + _In_ const char* const* input_names, _In_ size_t num_inputs, + _In_ const char* const* output_names, _In_ size_t num_outputs, + _In_ const char* const* attribute_names, _In_ OrtValue** attribute_values, _In_ size_t num_attributes)NO_EXCEPTION; + + OrtStatus*(ORT_API_CALL* ModelGetOpsetVersion)(_In_ OrtModel* model, _In_ const char* const domain, _Out_ int32_t* version)NO_EXCEPTION; + + OrtStatus*(ORT_API_CALL* OperatorGetNumInputs)( + _In_ const char* const op_type, + _In_ int64_t opset, + _In_ const char* const op_domain, + _Out_ size_t* num_inputs)NO_EXCEPTION; + + OrtStatus*(ORT_API_CALL* OperatorGetInputName)( + _In_ const char* const op_type, + _In_ int64_t opset, + _In_ const char* const op_domain, + _In_ size_t index, + _Out_ const char** const name + )NO_EXCEPTION; + + OrtStatus*(ORT_API_CALL* OperatorGetNumOutputs)( + _In_ const char* const op_type, + _In_ int64_t opset, + _In_ const char* const op_domain, + _Out_ size_t* num_inputs)NO_EXCEPTION; + + OrtStatus*(ORT_API_CALL* OperatorGetOutputName)( + _In_ const char* const op_type, + _In_ int64_t opset, + _In_ const char* const op_domain, + _In_ size_t index, + _Out_ const char** const name)NO_EXCEPTION; + ORT_CLASS_RELEASE(Model); }; diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp index 929852f82b732..f5d1e02cea09a 100644 --- a/winml/adapter/winml_adapter_model.cpp +++ b/winml/adapter/winml_adapter_model.cpp @@ -18,9 +18,14 @@ #include "google/protobuf/io/zero_copy_stream_impl.h" #include "core/framework/onnxruntime_typeinfo.h" +#include "onnx/defs/schema.h" +#include "core/framework/tensor_type_and_shape.h" + +#include "onnx/onnx-ml.pb.h" + namespace winmla = Windows::AI::MachineLearning::Adapter; -static std::vector GetInitializers(const onnx::ModelProto& model_proto) { +static std::vector GetInitializers(const ONNX_NAMESPACE::ModelProto& model_proto) { std::vector initializers; auto& graph = model_proto.graph(); auto& graph_initializers = graph.initializer(); @@ -30,10 +35,10 @@ static std::vector GetInitializers(const onnx::ModelProto& model_pr return initializers; } -static std::vector GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { +static std::vector GetInputsWithoutInitializers(const ONNX_NAMESPACE::ModelProto& model_proto) { auto initializers = GetInitializers(model_proto); - std::vector inputs_without_initializers; + std::vector inputs_without_initializers; auto& graph = model_proto.graph(); auto& inputs = graph.input(); for (auto& input : inputs) { @@ -54,8 +59,8 @@ static std::vector GetInputsWithoutInitializers(con return inputs_without_initializers; } -static std::vector GetOutputs(const onnx::ModelProto& model_proto) { - std::vector outputs_with_name; +static std::vector GetOutputs(const ONNX_NAMESPACE::ModelProto& model_proto) { + std::vector outputs_with_name; auto& graph = model_proto.graph(); auto& outputs = graph.output(); for (auto& output : outputs) { @@ -68,7 +73,7 @@ static std::vector GetOutputs(const onnx::ModelProt class ModelInfo { public: - ModelInfo(const onnx::ModelProto* model_proto) { + ModelInfo(const ONNX_NAMESPACE::ModelProto* model_proto) { Initialize(model_proto); } @@ -80,12 +85,12 @@ class ModelInfo { std::string description_; int64_t version_; std::vector> model_metadata_; - std::vector input_features_; - std::vector output_features_; + std::vector input_features_; + std::vector output_features_; bool requires_float16_support_; private: - void Initialize(const onnx::ModelProto* model_proto) { + void Initialize(const ONNX_NAMESPACE::ModelProto* model_proto) { for (auto& prop : model_proto->metadata_props()) { model_metadata_.push_back(std::make_pair(prop.key(), prop.value())); } @@ -112,12 +117,12 @@ class ModelInfo { } }; -OrtModel::OrtModel(std::unique_ptr model_proto) : model_proto_(std::move(model_proto)), - model_info_(std::make_unique(model_proto_.get())) { +OrtModel::OrtModel(std::unique_ptr model_proto) : model_proto_(std::move(model_proto)), + model_info_(std::make_unique(model_proto_.get())) { } // factory methods for creating an ort model from a path -static OrtStatus* CreateModelProto(const char* path, std::unique_ptr& out) { +static OrtStatus* CreateModelProto(const char* path, std::unique_ptr& out) { int file_descriptor; auto path_str = std::string(path); @@ -144,7 +149,7 @@ static OrtStatus* CreateModelProto(const char* path, std::unique_ptr(new onnx::ModelProto()); + auto model_proto = std::unique_ptr(new ONNX_NAMESPACE::ModelProto()); auto parse_succeeded = model_proto->ParseFromZeroCopyStream(&stream); if (!parse_succeeded) { @@ -156,10 +161,18 @@ static OrtStatus* CreateModelProto(const char* path, std::unique_ptr(new ONNX_NAMESPACE::ModelProto()); + auto opsetimportproto = model_proto->add_opset_import(); + opsetimportproto->set_version(opset); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model); +} + OrtStatus* OrtModel::CreateOrtModelFromPath(const char* path, size_t len, OrtModel** model) { ORT_UNUSED_PARAMETER(len); - std::unique_ptr model_proto; + std::unique_ptr model_proto; if (auto status = CreateModelProto(path, model_proto)) { return status; @@ -169,7 +182,7 @@ OrtStatus* OrtModel::CreateOrtModelFromPath(const char* path, size_t len, OrtMod } OrtStatus* OrtModel::CreateOrtModelFromData(void* data, size_t len, OrtModel** model) { - auto model_proto = std::unique_ptr(new onnx::ModelProto()); + auto model_proto = std::unique_ptr(new ONNX_NAMESPACE::ModelProto()); auto parse_succeeded = model_proto->ParseFromArray(data, static_cast(len)); if (!parse_succeeded) { @@ -179,7 +192,7 @@ OrtStatus* OrtModel::CreateOrtModelFromData(void* data, size_t len, OrtModel** m return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model); } -OrtStatus* OrtModel::CreateOrtModelFromProto(std::unique_ptr&& model_proto, OrtModel** model) { +OrtStatus* OrtModel::CreateOrtModelFromProto(std::unique_ptr&& model_proto, OrtModel** model) { *model = new (std::nothrow) OrtModel(std::move(model_proto)); if (*model == nullptr) { return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Engine failed to create a model!"); @@ -192,11 +205,11 @@ const ModelInfo* OrtModel::UseModelInfo() const { return model_info_.get(); } -const ONNX_NAMESPACE::ModelProto* OrtModel::UseModelProto() const { +ONNX_NAMESPACE::ModelProto* OrtModel::UseModelProto() const { return model_proto_.get(); } -std::unique_ptr OrtModel::DetachModelProto() { +std::unique_ptr OrtModel::DetachModelProto() { return std::move(model_proto_); } @@ -220,7 +233,7 @@ ORT_API_STATUS_IMPL(winmla::CreateModelFromData, _In_ void* data, _In_ size_t si ORT_API_STATUS_IMPL(winmla::CloneModel, _In_ const OrtModel* in, _Outptr_ OrtModel** out) { API_IMPL_BEGIN - auto model_proto_copy = std::make_unique(*in->UseModelProto()); + auto model_proto_copy = std::make_unique(*in->UseModelProto()); if (auto status = OrtModel::CreateOrtModelFromProto(std::move(model_proto_copy), out)) { return status; } @@ -228,6 +241,26 @@ ORT_API_STATUS_IMPL(winmla::CloneModel, _In_ const OrtModel* in, _Outptr_ OrtMod API_IMPL_END } +ORT_API_STATUS_IMPL(winmla::SaveModel, const OrtModel* in, const wchar_t* const file_name, size_t len) { + API_IMPL_BEGIN + int fd; + std::wstring file_path = file_name; + Status status = onnxruntime::Env::Default().FileOpenWr(file_path, fd); + if (fd < 0) { + return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "File not found!"); + } + + auto model_proto = in->UseModelProto(); + google::protobuf::io::FileOutputStream output(fd); + const bool success = model_proto->SerializeToZeroCopyStream(&output) && output.Flush(); + if (!success) { + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Failed to serialize model!"); + } + output.Close(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(winmla::ModelGetAuthor, _In_ const OrtModel* model, _Out_ const char** const author, _Out_ size_t* len) { API_IMPL_BEGIN *author = model->UseModelInfo()->author_.c_str(); @@ -387,7 +420,7 @@ ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, _In_ const OrtModel* model) { for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) { auto attribute = node.attribute(attribIndex); if (attribute.name() == "to") { - if (attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16) { + if (attribute.i() == ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT16) { std::stringstream error_message; error_message << "The model contains a 16-bit input (" << node.name().c_str() @@ -403,7 +436,7 @@ ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, _In_ const OrtModel* model) { // tensors via the Cast (to float16) operator for (int i = 0; i < graph.initializer_size(); i++) { auto initializer = graph.initializer(i); - if (initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16) { + if (initializer.data_type() == ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT16) { std::stringstream error_message; error_message << "The model contains a 16-bit input (" << initializer.name().c_str() @@ -430,6 +463,312 @@ ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, _In_ const OrtModel* model) { API_IMPL_END } +ORT_API_STATUS_IMPL(winmla::CreateModel, int64_t opset, OrtModel** out) { + API_IMPL_BEGIN + return OrtModel::CreateEmptyModel(opset, out); + API_IMPL_END +} + +static ONNX_NAMESPACE::TensorProto_DataType ONNXTensorElementDataTypeToTensorProto_DataType(ONNXTensorElementDataType type) { + switch (type) { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ONNX_NAMESPACE::TensorProto_DataType_UINT8; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return ONNX_NAMESPACE::TensorProto_DataType_INT8; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return ONNX_NAMESPACE::TensorProto_DataType_UINT16; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return ONNX_NAMESPACE::TensorProto_DataType_INT16; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ONNX_NAMESPACE::TensorProto_DataType_INT32; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return ONNX_NAMESPACE::TensorProto_DataType_INT64; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + return ONNX_NAMESPACE::TensorProto_DataType_STRING; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return ONNX_NAMESPACE::TensorProto_DataType_BOOL; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ONNX_NAMESPACE::TensorProto_DataType_UINT32; + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return ONNX_NAMESPACE::TensorProto_DataType_UINT64; + default: + return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + } +} + +static void CreateTypeProto_Tensor(ONNX_NAMESPACE::TypeProto_Tensor* mutable_tensor_type, const char* const name, + const int64_t* shape, size_t shape_len, ONNX_NAMESPACE::TensorProto_DataType data_type) { + mutable_tensor_type->set_elem_type(data_type); + + size_t dim_param = 0; + for (size_t i = 0; i < shape_len; i++) { + if (shape[i] == -1) { + std::ostringstream str; + str << name << dim_param++; + mutable_tensor_type->mutable_shape()->add_dim()->set_dim_param(str.str().c_str(), 1); + } else { + mutable_tensor_type->mutable_shape()->add_dim()->set_dim_value(shape[i]); + } + } + + if (shape_len > 0) { + mutable_tensor_type->mutable_shape()->mutable_dim(0)->set_denotation("DATA_BATCH"); + } +} + +ORT_API_STATUS_IMPL(winmla::ModelAddInput, _In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info) { + API_IMPL_BEGIN + auto model_proto = model->UseModelProto(); + ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph(); + ONNX_NAMESPACE::ValueInfoProto& input = *graph.add_input(); + input.set_name(input_name); + + if (info->type == ONNXType::ONNX_TYPE_TENSOR) { + auto num_dims = info->data->shape.NumDimensions(); + CreateTypeProto_Tensor( + input.mutable_type()->mutable_tensor_type(), + input_name, + (num_dims == 0) ? nullptr : &info->data->shape[0], + num_dims, + ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelAddConstantInput, _In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info, _In_ OrtValue* value) { + API_IMPL_BEGIN + auto model_proto = model->UseModelProto(); + ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph(); + ONNX_NAMESPACE::TensorProto& input = *graph.add_initializer(); + input.set_name(input_name); + + auto num_dims = info->data->shape.NumDimensions(); + for (size_t i = 0; i < num_dims; i++) { + input.add_dims(info->data->shape[i]); + } + + input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)); + auto tensor = value->GetMutable(); + input.set_raw_data(tensor->DataRaw(), tensor->SizeInBytes()); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelAddOutput, _In_ OrtModel* model, _In_ const char* const output_name, _In_ OrtTypeInfo* info) { + API_IMPL_BEGIN + auto model_proto = model->UseModelProto(); + ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph(); + ONNX_NAMESPACE::ValueInfoProto& output = *graph.add_output(); + output.set_name(output_name); + + if (info->type == ONNXType::ONNX_TYPE_TENSOR) { + CreateTypeProto_Tensor( + output.mutable_type()->mutable_tensor_type(), + output_name, + &info->data->shape[0], + info->data->shape.NumDimensions(), + ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)); + } + return nullptr; + API_IMPL_END +} + +static const onnx::OpSchema* GetSchema(const char* const op_type, int64_t opset, const char* const op_domain) { + std::string domain = onnx::ONNX_DOMAIN; + if (op_domain) { + domain = op_domain; + } + + auto registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); + return registry->GetSchema(op_type, static_cast(opset), domain); +} + +ORT_API_STATUS_IMPL(winmla::ModelAddOperator, + _In_ OrtModel* model, + _In_ const char* const op_type, + _In_ const char* const op_name, + _In_ int64_t opset, + _In_ const char* const op_domain, + _In_ const char* const* input_names, _In_ size_t num_inputs, + _In_ const char* const* output_names, _In_ size_t num_outputs, + _In_ const char* const* attribute_names, _In_ OrtValue** attribute_values, _In_ size_t num_attributes) { + API_IMPL_BEGIN + auto model_proto = model->UseModelProto(); + ONNX_NAMESPACE::GraphProto& graph = *model_proto->mutable_graph(); + onnx::NodeProto& node = *graph.add_node(); + node.set_op_type(op_type); + node.set_name(op_name); + node.set_domain(op_domain); + + auto schema = GetSchema(op_type, opset, op_domain); + auto all_attributes = schema->attributes(); + + for (size_t i = 0; i < num_attributes; i++) { + auto tensor = attribute_values[i]->GetMutable(); + + auto attr = node.add_attribute(); + attr->set_name(attribute_names[i]); + auto& schema_attribute_definition = all_attributes.at(attribute_names[i]); + attr->set_type(schema_attribute_definition.type); + + switch (schema_attribute_definition.type) { + case onnx::AttributeProto_AttributeType_INT: { + if (tensor->Shape().Size() != 1) { + return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Expected a single int64 value!"); + } + auto raw_data = tensor->DataRaw(); + attr->set_i(*reinterpret_cast(raw_data)); + break; + } + case onnx::AttributeProto_AttributeType_FLOAT: { + if (tensor->Shape().Size() != 1) { + return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Expected a single float value!"); + } + auto raw_data = tensor->DataRaw(); + attr->set_f(*reinterpret_cast(raw_data)); + break; + } + case onnx::AttributeProto_AttributeType_INTS: { + auto raw_data = tensor->DataRaw(); + for (int j = 0; j < tensor->Shape().Size(); j++) { + attr->add_ints(*(reinterpret_cast(raw_data)+j)); + } + break; + } + case onnx::AttributeProto_AttributeType_FLOATS: { + auto raw_data = tensor->DataRaw(); + for (int j = 0; j < tensor->Shape().Size(); j++) { + attr->add_floats(*(reinterpret_cast(raw_data) + j)); + } + break; + } + case onnx::AttributeProto_AttributeType_TENSOR: { + auto tensor_proto = attr->add_tensors(); + auto prim_type = tensor->DataType()->AsPrimitiveDataType(); + if (prim_type == nullptr) { + return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Undefined tensor type!"); + } + tensor_proto->set_data_type(prim_type->GetDataType()); + tensor_proto->set_raw_data(tensor->DataRaw(), tensor->SizeInBytes()); + break; + } + } + } + + for (size_t i = 0; i < num_inputs; i++) { + auto name = input_names[i]; + if (name != nullptr) { + node.add_input(name); + } else { + node.add_input(); + } + } + + for (size_t i = 0; i < num_outputs; i++) { + auto name = output_names[i]; + if (name != nullptr) { + node.add_output(name); + } + else { + node.add_output("unused"); + } + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOpsetVersion, + _In_ OrtModel* model, + _In_ const char* const domain, + _Out_ int32_t* version) { + API_IMPL_BEGIN + auto model_proto = model->UseModelProto(); + + *version = -1; + auto size = static_cast(model_proto->opset_import_size()); + for (int i = 0; i < size; i++) { + auto& current_opset = model_proto->opset_import(i); + auto& current_domain = current_opset.domain(); + if (_strnicmp(domain, current_domain.c_str(), current_domain.size()) == 0) { + *version = static_cast(current_opset.version()); + break; + } + } + + return nullptr; + API_IMPL_END +} + ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) { delete ptr; +} + +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/framework/tensor_type_and_shape.h" + +OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape shape, const std::vector* dim_params, OrtTensorTypeAndShapeInfo** out); + +ORT_API_STATUS_IMPL(winmla::CreateTensorTypeInfo, _In_ const int64_t* dim_values, size_t dim_count, ONNXTensorElementDataType type, _Out_ OrtTypeInfo** ort_type_info) { + API_IMPL_BEGIN + OrtTensorTypeAndShapeInfo* data = nullptr; + auto tensor_shape = onnxruntime::TensorShape(dim_values, dim_count); + auto st = GetTensorShapeAndTypeHelper(type, tensor_shape, nullptr, &data); + if (st != nullptr){ + return st; + } + *ort_type_info = new OrtTypeInfo(ONNX_TYPE_TENSOR, data); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::CreateSequenceTypeInfo, _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::CreateMapTypeInfo, _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::OperatorGetNumInputs, _In_ const char* const op_type, _In_ int64_t opset, _In_ const char* const op_domain, _Out_ size_t* num_inputs) { + API_IMPL_BEGIN + auto schema = GetSchema(op_type, opset, op_domain); + *num_inputs = schema->inputs().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::OperatorGetInputName, _In_ const char* const op_type, _In_ int64_t opset, _In_ const char* const op_domain, _In_ size_t index, _Out_ const char** const name) { + API_IMPL_BEGIN + auto schema = GetSchema(op_type, opset, op_domain); + *name = schema->inputs().at(index).GetName().c_str(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::OperatorGetNumOutputs, _In_ const char* const op_type, _In_ int64_t opset, _In_ const char* const op_domain, _Out_ size_t* num_outputs) { + API_IMPL_BEGIN + auto schema = GetSchema(op_type, opset, op_domain); + *num_outputs = schema->outputs().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::OperatorGetOutputName, _In_ const char* const op_type, _In_ int64_t opset, _In_ const char* const op_domain, _In_ size_t index, _Out_ const char** const name) { + API_IMPL_BEGIN + auto schema = GetSchema(op_type, opset, op_domain); + *name = schema->outputs().at(index).GetName().c_str(); + return nullptr; + API_IMPL_END } \ No newline at end of file diff --git a/winml/adapter/winml_adapter_model.h b/winml/adapter/winml_adapter_model.h index 1dabfa23dfa0b..db68e58bc9ac5 100644 --- a/winml/adapter/winml_adapter_model.h +++ b/winml/adapter/winml_adapter_model.h @@ -11,12 +11,13 @@ class ModelInfo; struct OrtModel { public: + static OrtStatus* CreateEmptyModel(int64_t opset, OrtModel** model); static OrtStatus* CreateOrtModelFromPath(const char* path, size_t len, OrtModel** model); static OrtStatus* CreateOrtModelFromData(void* data, size_t len, OrtModel** model); static OrtStatus* CreateOrtModelFromProto(std::unique_ptr&& model_proto, OrtModel** model); const ModelInfo* UseModelInfo() const; - const onnx::ModelProto* UseModelProto() const; + onnx::ModelProto* UseModelProto() const; std::unique_ptr DetachModelProto(); private: diff --git a/winml/api/Microsoft.AI.MachineLearning.Experimental.idl b/winml/api/Microsoft.AI.MachineLearning.Experimental.idl index 6eeb6d8b71cb0..f9806db92b43f 100644 --- a/winml/api/Microsoft.AI.MachineLearning.Experimental.idl +++ b/winml/api/Microsoft.AI.MachineLearning.Experimental.idl @@ -20,17 +20,8 @@ import "Windows.AI.MachineLearning.idl"; #define ROOT_NS Microsoft #endif - namespace ROOT_NS.AI.MachineLearning.Experimental { - - [threading(both)] - [marshaling_behavior(agile)] - [dualapipartition(1)] - runtimeclass Dummy { - Dummy(); - - void Test(); - } + runtimeclass LearningModelBuilder; [marshaling_behavior(agile)] [dualapipartition(1)] @@ -46,4 +37,68 @@ namespace ROOT_NS.AI.MachineLearning.Experimental { LearningModelSessionOptionsExperimental Options { get; }; } -} // namespace Microsoft.AI.MachineLearning.Experimental \ No newline at end of file + [threading(both)] + [marshaling_behavior(agile)] + [dualapipartition(1)] + runtimeclass LearningModelOperator { + LearningModelOperator(String type); + LearningModelOperator(String type, String domain); + + LearningModelOperator SetName(String name); + LearningModelOperator SetInput(String operator_input_name, String input_name); + LearningModelOperator SetConstant(String operator_input_name, IInspectable default_value); + LearningModelOperator SetOutput(String operator_output_name, String output_name); + LearningModelOperator SetAttribute(String name, IInspectable value); + + String Name { get; }; + String Type { get; }; + String Domain { get; }; + } + + [marshaling_behavior(agile)] + [dualapipartition(1)] + runtimeclass LearningModelOperatorSet { + LearningModelBuilder Add(LearningModelOperator op); + } + + [marshaling_behavior(agile)] + [dualapipartition(1)] + runtimeclass LearningModelInputs { + LearningModelBuilder Add(ROOT_NS.AI.MachineLearning.ILearningModelFeatureDescriptor input); + LearningModelBuilder Add(String input_name, String input_description, IInspectable default_value); + LearningModelBuilder AddConstant(String input_name, IInspectable value); + } + + [marshaling_behavior(agile)] + [dualapipartition(1)] + runtimeclass LearningModelOutputs { + LearningModelBuilder Add(ROOT_NS.AI.MachineLearning.ILearningModelFeatureDescriptor output); + } + + //! \interface LearningModelBuilder + //! \brief Represents a trained machine learning model. + //! \details This is the main object you use to interact with Windows Machine Learning. You use + //! it to load, bind, and evaluate trained ONNX models. To load the model you use + //! one of the Load constructors. You can then enumerate the InputFeatures and + //! OutputFeatures. To bind and evaluate you create a LearningModelSession. + [threading(both)] + [marshaling_behavior(agile)] + [dualapipartition(1)] + runtimeclass LearningModelBuilder { + LearningModelInputs Inputs { get; }; + LearningModelOutputs Outputs { get; }; + LearningModelOperatorSet Operators { get; }; + + //! Create a builder. + static LearningModelBuilder Create(Int32 opset); + + //! Creates a TensorFeatureDescriptor.. this should be a constructor on the TFD + //TensorFeatureDescriptor(String name, String description, TensorKind kind, Int64[] shape); + static ROOT_NS.AI.MachineLearning.TensorFeatureDescriptor CreateTensorFeatureDescriptor(String name, String description, ROOT_NS.AI.MachineLearning.TensorKind kind, Int64[] shape); + static ROOT_NS.AI.MachineLearning.TensorFeatureDescriptor CreateTensorFeatureDescriptor(String name, ROOT_NS.AI.MachineLearning.TensorKind kind, Int64[] shape); + + ROOT_NS.AI.MachineLearning.LearningModel CreateModel(); + + void Save(String file_name); + } +} // namespace Microsoft.AI.MachineLearning.Experimental diff --git a/winml/dll/module.cpp b/winml/dll/module.cpp index 0f6745d86739c..e38dcc3cf3214 100644 --- a/winml/dll/module.cpp +++ b/winml/dll/module.cpp @@ -9,7 +9,8 @@ #ifndef BUILD_INBOX -#include "Dummy.h" +#include "LearningModelBuilder.h" +#include "LearningModelOperator.h" #include "LearningModelSessionOptionsExperimental.h" #include "LearningModelSessionExperimental.h" @@ -93,14 +94,24 @@ STDAPI DllGetExperimentalActivationFactory(void* classId, void** factory) noexce return std::equal(left.rbegin(), left.rend(), right.rbegin(), right.rend()); }; - winrt::hstring winml_namespace = winrt::to_hstring(XSTRINGIFY(WINML_ROOT_NS)); + std::wostringstream learning_model_builder_class; + learning_model_builder_class << XSTRINGIFY(WINML_ROOT_NS) << ".AI.MachineLearning.Experimental.LearningModelBuilder"; + if (requal(name, learning_model_builder_class.str())) { + *factory = winrt::detach_abi(winrt::make()); + return 0; + } + + std::wostringstream learning_model_operator_class; + learning_model_operator_class << XSTRINGIFY(WINML_ROOT_NS) << ".AI.MachineLearning.Experimental.LearningModelOperator"; + if (requal(name, learning_model_operator_class.str())) { + *factory = winrt::detach_abi(winrt::make()); - if (requal(name, winml_namespace + L".AI.MachineLearning.Experimental.Dummy")) { - *factory = winrt::detach_abi(winrt::make()); return 0; } - if (requal(name, winml_namespace + L".AI.MachineLearning.Experimental.LearningModelSessionExperimental")) { + std::wostringstream learning_model_session_experimental_class; + learning_model_session_experimental_class << XSTRINGIFY(WINML_ROOT_NS) << ".AI.MachineLearning.Experimental.LearningModelSessionExperimental"; + if (requal(name, learning_model_session_experimental_class.str())) { *factory = winrt::detach_abi(winrt::make()); return 0; } @@ -122,4 +133,4 @@ STDAPI DllGetActivationFactory(HSTRING classId, void** factory) { #endif return ret; -} \ No newline at end of file +} diff --git a/winml/lib/Api.Experimental/Dummy.cpp b/winml/lib/Api.Experimental/Dummy.cpp deleted file mode 100644 index ee57f04cf5df4..0000000000000 --- a/winml/lib/Api.Experimental/Dummy.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "pch.h" -#include "Dummy.h" - -namespace WINML_EXPERIMENTALP { - -void Dummy::Test() -{ - throw hresult_not_implemented(); -} - -} // namespace WINML_EXPERIMENTALP - diff --git a/winml/lib/Api.Experimental/Dummy.h b/winml/lib/Api.Experimental/Dummy.h deleted file mode 100644 index f72b8362dafab..0000000000000 --- a/winml/lib/Api.Experimental/Dummy.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include "Dummy.g.h" - -namespace WINML_EXPERIMENTALP { - -struct Dummy : DummyT -{ - Dummy() = default; - - void Test(); -}; - -} - -namespace WINML_EXPERIMENTAL::factory_implementation { - -struct Dummy : DummyT -{ -}; - -} diff --git a/winml/lib/Api.Experimental/LearningModelBuilder.cpp b/winml/lib/Api.Experimental/LearningModelBuilder.cpp new file mode 100644 index 0000000000000..79567c3ca0248 --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelBuilder.cpp @@ -0,0 +1,79 @@ +#include "pch.h" +#include "LearningModelBuilder.h" +#include "LearningModel.h" +#include "TensorFeatureDescriptor.h" +#include "LearningModelSession.h" +#include "LearningModelInputs.h" +#include "LearningModelOutputs.h" +#include "LearningModelOperatorSet.h" +#include "OnnxruntimeProvider.h" + +namespace WINML_EXPERIMENTALP { + +LearningModelBuilder::LearningModelBuilder(int64_t opset) : inputs_(nullptr), outputs_(nullptr), operators_(nullptr), inert_session_(nullptr) { + WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put())); + WINML_THROW_IF_FAILED(engine_factory_->CreateEmptyModel(opset, model_.put())); + inputs_ = winrt::make(*this); + outputs_ = winrt::make(*this); + operators_ = winrt::make(*this); + + winrt::com_ptr<_winml::IEngineBuilder> builder; + WINML_THROW_IF_FAILED(engine_factory_->CreateEngineBuilder(builder.put())); + winrt::com_ptr<_winml::IEngine> engine; + WINML_THROW_IF_FAILED(builder->CreateEngine(engine.put())); + inert_session_ = winmlp::LearningModelSession::CreateInertSession(engine.get()); +} + +LearningModelBuilder::LearningModelBuilder(LearningModelBuilder& builder) : inputs_(builder.inputs_), + outputs_(builder.outputs_), + operators_(builder.operators_), + inert_session_(nullptr) +{ +} + +winml_experimental::LearningModelInputs LearningModelBuilder::Inputs() { + return inputs_; +} + +winml_experimental::LearningModelOutputs LearningModelBuilder::Outputs() { + return outputs_; +} + +winml_experimental::LearningModelOperatorSet LearningModelBuilder::Operators() { + return operators_; +} + +winml::LearningModel LearningModelBuilder::CreateModel() { + com_ptr<_winml::IModel> model_clone; + model_->CloneModel(model_clone.put()); + return winrt::make(engine_factory_.get(), model_clone.get(), nullptr); +} + +void LearningModelBuilder::Save(const winrt::hstring& file_name) { + model_->SaveModel(file_name.c_str(), file_name.size()); +} + +winml_experimental::LearningModelBuilder LearningModelBuilder::Create(int32_t opset) { + return winrt::make(static_cast(opset)); +} + +winml::TensorFeatureDescriptor LearningModelBuilder::CreateTensorFeatureDescriptor( + hstring const& name, + winml::TensorKind const& kind, + array_view shape) { + return winrt::make(name, L"", kind, shape); +} + +winml::TensorFeatureDescriptor LearningModelBuilder::CreateTensorFeatureDescriptor( + hstring const& name, + hstring const& description, + winml::TensorKind const& kind, + array_view shape) { + return winrt::make(name, description, kind, shape); +} + +_winml::IModel* LearningModelBuilder::UseModel() { + return model_.get(); +} + +} // namespace WINML_EXPERIMENTALP diff --git a/winml/lib/Api.Experimental/LearningModelBuilder.h b/winml/lib/Api.Experimental/LearningModelBuilder.h new file mode 100644 index 0000000000000..1987e3bbbec74 --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelBuilder.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include "LearningModelBuilder.g.h" +#include "iengine.h" + +namespace WINML_EXPERIMENTALP { + +struct LearningModelBuilder : LearningModelBuilderT { + LearningModelBuilder(int64_t opset); + LearningModelBuilder(LearningModelBuilder& builder); + + winml_experimental::LearningModelInputs Inputs(); + winml_experimental::LearningModelOutputs Outputs(); + winml_experimental::LearningModelOperatorSet Operators(); + winml::LearningModel CreateModel(); + void Save(const winrt::hstring& file_name); + + static winml_experimental::LearningModelBuilder Create(int32_t opset); + + static winml::TensorFeatureDescriptor CreateTensorFeatureDescriptor( + hstring const& name, + hstring const& description, + winml::TensorKind const& kind, + array_view shape); + + static winml::TensorFeatureDescriptor CreateTensorFeatureDescriptor( + hstring const& name, + winml::TensorKind const& kind, + array_view shape); + + _winml::IModel* UseModel(); + + winml::LearningModelSession InertSession() { + return inert_session_; + } + + private: + com_ptr<_winml::IEngineFactory> engine_factory_; + winml::LearningModelSession inert_session_; + com_ptr<_winml::IModel> model_; + + winml_experimental::LearningModelInputs inputs_; + winml_experimental::LearningModelOutputs outputs_; + winml_experimental::LearningModelOperatorSet operators_; +}; +} // WINML_EXPERIMENTALP + +namespace WINML_EXPERIMENTAL::factory_implementation { +struct LearningModelBuilder : LearningModelBuilderT { +}; +} // namespace winrt::winml_experimental::factory_implementation diff --git a/winml/lib/Api.Experimental/LearningModelInputs.cpp b/winml/lib/Api.Experimental/LearningModelInputs.cpp new file mode 100644 index 0000000000000..e2585edf995a0 --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelInputs.cpp @@ -0,0 +1,69 @@ +#include "pch.h" +#include "LearningModelInputs.h" +#include "LearningModelOperator.h" +#include "LearningModelSession.h" +#include "LearningModelBuilder.h" +#include "TensorFeatureDescriptor.h" + +#include "..\Api\inc\ILotusValueProviderPrivate.h" + +namespace WINML_EXPERIMENTALP { + +LearningModelInputs::LearningModelInputs(winml_experimental::LearningModelBuilder builder) : builder_(builder), + input_descriptors_(winrt::single_threaded_vector()), + input_default_values_(winrt::single_threaded_vector()), + constant_descriptors_(winrt::single_threaded_vector()), + constant_values_(winrt::single_threaded_vector()) { +} + +winml_experimental::LearningModelBuilder LearningModelInputs::AddInput(winml::ILearningModelFeatureDescriptor const& input, Windows::Foundation::IInspectable const& default_value, bool is_constant) { + // Perform model update inside the builder + auto model = builder_.as()->UseModel(); + auto descriptor_provider = input.as<_winml::IDescriptorInfoProvider>(); + auto input_name = _winml::Strings::UTF8FromHString(input.Name()); + winrt::com_ptr<_winml::IValue> default_value_ivalue; + + if (default_value) { + auto default_value_value_provider = default_value.as<_winml::ILotusValueProviderPrivate>(); + // Create the Binding Context to pass to the feature value + _winml::BindingContext context{ + _winml::BindingType::kInput, + builder_.as()->InertSession(), + nullptr, + nullptr, + {} // SubresourceId is set by callee + }; + default_value_value_provider->GetValue(context, default_value_ivalue.put()); + } + + model->AddModelInput(input_name.c_str(), descriptor_provider.get(), is_constant, default_value_ivalue.get()); + + return builder_; +} + +winml_experimental::LearningModelBuilder LearningModelInputs::Add(winml::ILearningModelFeatureDescriptor const& input) { + return AddInput(input, nullptr, false); +} + +winml_experimental::LearningModelBuilder LearningModelInputs::Add(hstring const& input_name, hstring const& input_description, Windows::Foundation::IInspectable const& default_value) { + if (auto tensor = default_value.try_as()) { + auto shape = tensor.Shape(); + std::vector shape_vector(begin(shape), end(shape)); + auto descriptor = winrt::make(input_name, input_description, tensor.TensorKind(), shape_vector); + return AddInput(descriptor, default_value, false); + } + WINML_THROW_HR(E_UNEXPECTED); +} + +winml_experimental::LearningModelBuilder LearningModelInputs::AddConstant(hstring const& input_name, Windows::Foundation::IInspectable const& value) { + if (auto tensor = value.try_as()) { + winrt::hstring no_description_for_constants = L""; + auto shape = tensor.Shape(); + std::vector shape_vector(begin(shape), end(shape)); + auto descriptor = winrt::make(input_name, no_description_for_constants, tensor.TensorKind(), shape_vector); + return AddInput(descriptor, value, true); + } + WINML_THROW_HR(E_UNEXPECTED); +} + +} // namespace WINML_EXPERIMENTALP \ No newline at end of file diff --git a/winml/lib/Api.Experimental/LearningModelInputs.h b/winml/lib/Api.Experimental/LearningModelInputs.h new file mode 100644 index 0000000000000..e6f4442eba3a7 --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelInputs.h @@ -0,0 +1,23 @@ +#pragma once + +#include "LearningModelInputs.g.h" +#include "LearningModelBuilder.h" + +namespace WINML_EXPERIMENTALP { + +struct LearningModelInputs : LearningModelInputsT { + LearningModelInputs(winml_experimental::LearningModelBuilder builder); + + winml_experimental::LearningModelBuilder Add(winml::ILearningModelFeatureDescriptor const& input); + winml_experimental::LearningModelBuilder Add(hstring const& input_name, hstring const& input_description, Windows::Foundation::IInspectable const& default_value); + winml_experimental::LearningModelBuilder AddConstant(hstring const& input_name, Windows::Foundation::IInspectable const& value); + winml_experimental::LearningModelBuilder AddInput(winml::ILearningModelFeatureDescriptor const& input, Windows::Foundation::IInspectable const& default_value, bool is_constant); + + private: + wfc::IVector input_descriptors_; + wfc::IVector input_default_values_; + wfc::IVector constant_descriptors_; + wfc::IVector constant_values_; + winml_experimental::LearningModelBuilder builder_; +}; +} // namespace WINML_EXPERIMENTALP \ No newline at end of file diff --git a/winml/lib/Api.Experimental/LearningModelOperator.cpp b/winml/lib/Api.Experimental/LearningModelOperator.cpp new file mode 100644 index 0000000000000..ba35215048a5c --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelOperator.cpp @@ -0,0 +1,92 @@ +#include "pch.h" +#include "LearningModelOperator.h" + +namespace WINML_EXPERIMENTALP { + +static uint32_t c_operator_index = 0; + +LearningModelOperator::LearningModelOperator(hstring const& type) : + LearningModelOperator(type, L"") +{} + +LearningModelOperator::LearningModelOperator(hstring const& type, hstring const& domain) : + type_(type), + domain_(domain) { + constant_input_mapping_ = winrt::single_threaded_map(); + input_mapping_ = winrt::single_threaded_map(); + output_mapping_ = winrt::single_threaded_map(); + attribute_values_ = winrt::single_threaded_map(); + + SetName(L""); +} + +winml_experimental::LearningModelOperator LearningModelOperator::SetName(hstring const& name) { + if (name.empty()) { + std::wostringstream name_stream; + name_stream << type_.c_str() << "_" << c_operator_index++; + name_ = name_stream.str().c_str(); + } else { + name_ = name; + } + return *this; +} + +winml_experimental::LearningModelOperator LearningModelOperator::SetInput( + hstring const& operator_input_name, hstring const& input_name) { + + // TODO Validate against allowed operator input NAMES. The types are not deduced. + input_mapping_.Insert(operator_input_name, input_name); + return *this; +} + +winml_experimental::LearningModelOperator LearningModelOperator::SetConstant( + hstring const& operator_input_name, wf::IInspectable const& value) { + // TODO Validate against allowed operator input NAMES. The types are not deduced. + auto constant_name = name_ + L"." + operator_input_name; + input_mapping_.Insert(operator_input_name, constant_name); + constant_input_mapping_.Insert(constant_name, value); + return *this; +} + +winml_experimental::LearningModelOperator LearningModelOperator::SetOutput( + hstring const& operator_output_name, hstring const& output_name) { + // TODO Validate against allowed operator output NAMES. The types are not deduced. + output_mapping_.Insert(operator_output_name, output_name); + return *this; +} + +winml_experimental::LearningModelOperator LearningModelOperator::SetAttribute( + hstring const& name, Windows::Foundation::IInspectable const& value) { + attribute_values_.Insert(name, value); + return *this; +} + +hstring LearningModelOperator::Name() { + return name_; +} + +hstring LearningModelOperator::Type() { + return type_; +} + +hstring LearningModelOperator::Domain() { + return domain_; +} + +wfc::IMap LearningModelOperator::InputMapping(){ + return input_mapping_; +} + +wfc::IMap LearningModelOperator::ConstantInputMapping() { + return constant_input_mapping_; +} + +wfc::IMap LearningModelOperator::OutputMapping() { + return output_mapping_; +} + +wfc::IMap LearningModelOperator::AttributeMap() { + return attribute_values_; +} + +} // namespace WINML_EXPERIMENTALP \ No newline at end of file diff --git a/winml/lib/Api.Experimental/LearningModelOperator.h b/winml/lib/Api.Experimental/LearningModelOperator.h new file mode 100644 index 0000000000000..69d2b5aa98615 --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelOperator.h @@ -0,0 +1,50 @@ +#pragma once + +#include "LearningModelOperator.g.h" +#include "TensorFeatureDescriptor.h" +#include "iengine.h" +#include "LearningModelBuilder.h" +#include "LearningModelInputs.h" + +namespace WINML_EXPERIMENTALP { + +struct LearningModelOperator : LearningModelOperatorT +{ + LearningModelOperator() = delete; + LearningModelOperator(hstring const& type); + LearningModelOperator(hstring const& type, hstring const& domain); + + winml_experimental::LearningModelOperator SetName(hstring const& name); + winml_experimental::LearningModelOperator SetInput(hstring const& operator_input_name, hstring const& input_name); + winml_experimental::LearningModelOperator SetConstant(hstring const& operator_input_name, wf::IInspectable const& value); + winml_experimental::LearningModelOperator SetOutput(hstring const& operator_output_name, hstring const& output_name); + winml_experimental::LearningModelOperator SetAttribute(hstring const& name, wf::IInspectable const& value); + hstring Name(); + hstring Type(); + hstring Domain(); + + wfc::IMap InputMapping(); + wfc::IMap ConstantInputMapping(); + wfc::IMap OutputMapping(); + wfc::IMap AttributeMap(); + +private: + winrt::hstring name_; + winrt::hstring domain_; + winrt::hstring type_; + + wfc::IMap attribute_values_; + wfc::IMap constant_input_mapping_; + wfc::IMap input_mapping_; + wfc::IMap output_mapping_; +}; + +} // namespace WINML_EXPERIMENTALP + +namespace WINML_EXPERIMENTAL::factory_implementation { + +struct LearningModelOperator : LearningModelOperatorT +{ +}; + +} diff --git a/winml/lib/Api.Experimental/LearningModelOperatorSet.cpp b/winml/lib/Api.Experimental/LearningModelOperatorSet.cpp new file mode 100644 index 0000000000000..15111767f5f29 --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelOperatorSet.cpp @@ -0,0 +1,94 @@ +#include "pch.h" +#include "LearningModelOperatorSet.h" +#include "LearningModelOperator.h" + +#include "..\Api\inc\ILotusValueProviderPrivate.h" + +namespace WINML_EXPERIMENTALP { + +LearningModelOperatorSet::LearningModelOperatorSet(winml_experimental::LearningModelBuilder builder) : + builder_(builder), + operators_(winrt::single_threaded_vector()) +{ +} + +winml_experimental::LearningModelBuilder LearningModelOperatorSet::Add(winml_experimental::LearningModelOperator const& op) +{ + auto operator_private = op.as(); + auto constant_input_map = operator_private->ConstantInputMapping(); + auto input_map = operator_private->InputMapping(); + auto output_map = operator_private->OutputMapping(); + auto attribute_map = operator_private->AttributeMap(); + + auto operator_name = _winml::Strings::UTF8FromHString(operator_private->Name()); + auto operator_type = _winml::Strings::UTF8FromHString(operator_private->Type()); + auto operator_domain = _winml::Strings::UTF8FromHString(operator_private->Domain()); + + std::vector operator_input_names(input_map.Size()); + std::vector actual_input_names(input_map.Size()); + std::vector raw_operator_input_names(input_map.Size()); + std::vector raw_actual_input_names(input_map.Size()); + int i = 0; + for (auto kvp : input_map) { + operator_input_names[i] = _winml::Strings::UTF8FromHString(kvp.Key()); + actual_input_names[i] = _winml::Strings::UTF8FromHString(kvp.Value()); + raw_operator_input_names[i] = operator_input_names[i].c_str(); + raw_actual_input_names[i] = actual_input_names[i].c_str(); + i++; + } + + std::vector operator_output_names(output_map.Size()); + std::vector actual_output_names(output_map.Size()); + std::vector raw_operator_output_names(output_map.Size()); + std::vector raw_actual_output_names(output_map.Size()); + i = 0; + for (auto kvp : output_map) { + operator_output_names[i] = _winml::Strings::UTF8FromHString(kvp.Key()); + actual_output_names[i] = _winml::Strings::UTF8FromHString(kvp.Value()); + raw_operator_output_names[i] = operator_output_names[i].c_str(); + raw_actual_output_names[i] = actual_output_names[i].c_str(); + i++; + } + + // Create the Binding Context to pass to the feature value + _winml::BindingContext context{ + _winml::BindingType::kInput, + builder_.as()->InertSession(), + nullptr, + nullptr, + {} // SubresourceId is set by callee + }; + + std::vector attribute_names(attribute_map.Size()); + std::vector raw_attribute_names(attribute_map.Size()); + std::vector> attribute_values(attribute_map.Size()); + std::vector<_winml::IValue*> raw_attribute_values(attribute_map.Size()); + i = 0; + for (auto kvp : attribute_map) { + attribute_names[i] = _winml::Strings::UTF8FromHString(kvp.Key()); + auto default_value_value_provider = kvp.Value().as<_winml::ILotusValueProviderPrivate>(); + default_value_value_provider->GetValue(context, attribute_values[i].put()); + + raw_attribute_names[i] = attribute_names[i].c_str(); + raw_attribute_values[i] = attribute_values[i].get(); + i++; + } + + auto builder = builder_.as(); + WINML_THROW_IF_FAILED(builder->UseModel()->AddOperator( + operator_type.c_str(), + operator_name.c_str(), + operator_domain.c_str(), + raw_operator_input_names.data(), raw_actual_input_names.data(), input_map.Size(), + raw_operator_output_names.data(), raw_actual_output_names.data(), output_map.Size(), + raw_attribute_names.data(), raw_attribute_values.data(), attribute_map.Size())); + + // Add constants + for (auto kvp : constant_input_map) { + builder_.Inputs().AddConstant(kvp.Key(), kvp.Value()); + } + + return builder_; +} + +} diff --git a/winml/lib/Api.Experimental/LearningModelOperatorSet.h b/winml/lib/Api.Experimental/LearningModelOperatorSet.h new file mode 100644 index 0000000000000..8e6d842fb0b31 --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelOperatorSet.h @@ -0,0 +1,17 @@ +#pragma once + +#include "LearningModelOperatorSet.g.h" + +namespace WINML_EXPERIMENTALP { + + struct LearningModelOperatorSet : LearningModelOperatorSetT + { + LearningModelOperatorSet(winml_experimental::LearningModelBuilder builder); + + winml_experimental::LearningModelBuilder Add(winml_experimental::LearningModelOperator const& op); + + private: + winml_experimental::LearningModelBuilder builder_; + wfc::IVector operators_; + }; +} diff --git a/winml/lib/Api.Experimental/LearningModelOutputs.cpp b/winml/lib/Api.Experimental/LearningModelOutputs.cpp new file mode 100644 index 0000000000000..8de00f58d374d --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelOutputs.cpp @@ -0,0 +1,25 @@ +#include "pch.h" +#include "LearningModelOutputs.h" +#include "LearningModelBuilder.h" +#include "TensorFeatureDescriptor.h" + +namespace WINML_EXPERIMENTALP +{ + +LearningModelOutputs::LearningModelOutputs(winml_experimental::LearningModelBuilder builder) : + builder_(builder), + output_descriptors_(winrt::single_threaded_vector()) { +} + +winml_experimental::LearningModelBuilder LearningModelOutputs::Add(winml::ILearningModelFeatureDescriptor const& output) +{ + // Perform model update inside the builder + auto model = builder_.as()->UseModel(); + auto descriptor_provider = output.as<_winml::IDescriptorInfoProvider>(); + auto name = _winml::Strings::UTF8FromHString(output.Name()); + model->AddModelOutput(name.c_str(), descriptor_provider.get()); + output_descriptors_.Append(output); + return builder_; +} + +} // namespace WINML_EXPERIMENTALP diff --git a/winml/lib/Api.Experimental/LearningModelOutputs.h b/winml/lib/Api.Experimental/LearningModelOutputs.h new file mode 100644 index 0000000000000..c0d99a4412963 --- /dev/null +++ b/winml/lib/Api.Experimental/LearningModelOutputs.h @@ -0,0 +1,19 @@ +#pragma once + +#include "LearningModelOutputs.g.h" + +namespace WINML_EXPERIMENTALP { + +struct LearningModelOutputs : LearningModelOutputsT +{ + LearningModelOutputs(winml_experimental::LearningModelBuilder builder); + + winml_experimental::LearningModelBuilder Add(winml::ILearningModelFeatureDescriptor const& output); + + private: + wfc::IVector output_descriptors_; + winml_experimental::LearningModelBuilder builder_; + +}; + +} // namespace WINML_EXPERIMENTALP \ No newline at end of file diff --git a/winml/lib/Api.Experimental/LearningModelSessionExperimental.cpp b/winml/lib/Api.Experimental/LearningModelSessionExperimental.cpp index 4b3150f1f96ff..96e4ad7020ad8 100644 --- a/winml/lib/Api.Experimental/LearningModelSessionExperimental.cpp +++ b/winml/lib/Api.Experimental/LearningModelSessionExperimental.cpp @@ -5,12 +5,10 @@ namespace WINML_EXPERIMENTALP { -LearningModelSessionExperimental::LearningModelSessionExperimental(const winml::LearningModelSession& session) : _session(session) { - int i = 0; - i++; +LearningModelSessionExperimental::LearningModelSessionExperimental(const winml::LearningModelSession& session) : + _session(session) { } - WINML_EXPERIMENTAL::LearningModelSessionOptionsExperimental LearningModelSessionExperimental::Options() { return winrt::make(_session); } diff --git a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp index 58de84c6c1982..9e728ee9c6c3b 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp @@ -614,10 +614,11 @@ OnnxruntimeDescriptorConverter::OnnxruntimeDescriptorConverter( const std::unordered_map& metadata) : engine_factory_(engine_factory), metadata_(metadata) {} wfc::IVector -OnnxruntimeDescriptorConverter::ConvertToLearningModelDescriptors(const std::vector& descriptors) { +OnnxruntimeDescriptorConverter::ConvertToLearningModelDescriptors(const OnnxruntimeValueInfoWrapper* descriptors, size_t num_descriptors) { auto features = winrt::single_threaded_vector(); - for (const auto& descriptor : descriptors) { + for (size_t i = 0; i < num_descriptors; i++) { + const auto& descriptor = descriptors[i]; auto learning_model_descriptor = _winml::CreateFeatureDescriptor(engine_factory_.Get(), &descriptor, metadata_); features.Append(learning_model_descriptor); } diff --git a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h index dda03d431f7c7..f16749b5ea88d 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h +++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h @@ -23,7 +23,7 @@ struct OnnxruntimeDescriptorConverter { const std::unordered_map& model_metadata); wfc::IVector - ConvertToLearningModelDescriptors(const std::vector& descriptors); + ConvertToLearningModelDescriptors(const OnnxruntimeValueInfoWrapper* descriptors, size_t num_descriptors); private: Microsoft::WRL::ComPtr engine_factory_; diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp index 2db1bb6053e64..38182cf5bf3f3 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp @@ -1328,6 +1328,18 @@ STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ void* data, _In_ size_t return S_OK; } +STDMETHODIMP OnnxruntimeEngineFactory::CreateEmptyModel(int64_t opset, _Outptr_ _winml::IModel** out) { + RETURN_IF_FAILED(EnsureEnvironment()); + OrtModel* ort_model = nullptr; + if (auto status = winml_adapter_api_->CreateModel(opset, &ort_model)) { + return E_INVALIDARG; + } + + auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(model))); + return S_OK; +} + STDMETHODIMP OnnxruntimeEngineFactory::CreateEngineBuilder(_Outptr_ _winml::IEngineBuilder** out) { RETURN_IF_FAILED(EnsureEnvironment()); Microsoft::WRL::ComPtr onnxruntime_engine_builder; @@ -1368,3 +1380,60 @@ STDAPI CreateOnnxruntimeEngineFactory(_Out_ _winml::IEngineFactory** engine_fact RETURN_IF_FAILED(onnxruntime_engine_factory.CopyTo(engine_factory)); return S_OK; } +struct OrtDescriptorInfo : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IDescriptorInfo, + IOrtTypeInfoProvider> { + OrtDescriptorInfo() : info_(nullptr, nullptr) {} + + HRESULT RuntimeClassInitialize(UniqueOrtTypeInfo info) { + info_ = std::move(info); + return S_OK; + } + + STDMETHOD(GetTypeInfo)(OrtTypeInfo** info) override { + *info = info_.get(); + return S_OK; + } + + OrtTypeInfo* UseOrtTypeInfo() { + return info_.get(); + } + + private: + UniqueOrtTypeInfo info_; +}; + +HRESULT OnnxruntimeEngineFactory::CreateTensorDescriptorInfo(winml::TensorKind kind, int64_t* dims, + size_t num_dims, IDescriptorInfo** tensor_info) { + OrtTypeInfo* tensor_type_info = nullptr; + winml_adapter_api_->CreateTensorTypeInfo(dims, num_dims, ONNXTensorElementDataTypeFromTensorKind(kind), &tensor_type_info); + UniqueOrtTypeInfo info(tensor_type_info, ort_api_->ReleaseTypeInfo); + + Microsoft::WRL::ComPtr descriptor_info; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&descriptor_info, std::move(info))); + RETURN_IF_FAILED(descriptor_info.CopyTo(tensor_info)); + return S_OK; +} + +HRESULT OnnxruntimeEngineFactory::CreateSequenceDescriptorInfo(IDescriptorInfo** seq_info) { + OrtTypeInfo* sequence_type_info = nullptr; + winml_adapter_api_->CreateSequenceTypeInfo(&sequence_type_info); + UniqueOrtTypeInfo info(sequence_type_info, ort_api_->ReleaseTypeInfo); + + Microsoft::WRL::ComPtr descriptor_info; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&descriptor_info, std::move(info))); + RETURN_IF_FAILED(descriptor_info.CopyTo(seq_info)); + return S_OK; +} + +HRESULT OnnxruntimeEngineFactory::CreateMapDescriptorInfo(IDescriptorInfo** desc_info) { + OrtTypeInfo* map_type_info = nullptr; + winml_adapter_api_->CreateMapTypeInfo(&map_type_info); + UniqueOrtTypeInfo info(map_type_info, ort_api_->ReleaseTypeInfo); + + Microsoft::WRL::ComPtr descriptor_info; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&descriptor_info, std::move(info))); + RETURN_IF_FAILED(descriptor_info.CopyTo(desc_info)); + return S_OK; +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.h b/winml/lib/Api.Ort/OnnxruntimeEngine.h index b14b7e3820d2c..d3931f8c017bf 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.h @@ -139,6 +139,8 @@ class OnnxruntimeEngineFactory : public Microsoft::WRL::RuntimeClass< (_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) override; STDMETHOD(CreateModel) (_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) override; + STDMETHOD(CreateEmptyModel) + (_In_ int64_t opset, _Outptr_ IModel** out) override; STDMETHOD(CreateEngineBuilder) (_Outptr_ IEngineBuilder** engine_builder) override; STDMETHOD(EnableDebugOutput) @@ -151,7 +153,16 @@ class OnnxruntimeEngineFactory : public Microsoft::WRL::RuntimeClass< HRESULT EnsureEnvironment(); HRESULT GetOrtEnvironment(_Out_ OrtEnv** ort_env); - private: + STDMETHOD(CreateTensorDescriptorInfo) + (_In_ winml::TensorKind kind, _In_ int64_t* dims, _In_ size_t num_dims, _Out_ IDescriptorInfo** info) override; + + STDMETHOD(CreateSequenceDescriptorInfo) + (_Out_ IDescriptorInfo** info) override; + + STDMETHOD(CreateMapDescriptorInfo) + (_Out_ IDescriptorInfo** info) override; + +private: const OrtApi* ort_api_ = nullptr; const WinmlAdapterApi* winml_adapter_api_ = nullptr; std::shared_ptr environment_; diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h index d0cb7e2c9421e..0e02548701b2c 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h @@ -42,7 +42,7 @@ class OnnxruntimeEngineBuilder : public Microsoft::WRL::RuntimeClass< bool metacommands_enabled_ = true; std::optional batch_size_override_; wfc::IMapView named_dimension_overrides_; - uint32_t intra_op_num_threads_override_; + uint32_t intra_op_num_threads_override_ = 0; }; } // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.cpp b/winml/lib/Api.Ort/OnnxruntimeModel.cpp index a1354498a33ed..20287befc1b30 100644 --- a/winml/lib/Api.Ort/OnnxruntimeModel.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeModel.cpp @@ -81,7 +81,7 @@ HRESULT ModelInfo::RuntimeClassInitialize(_In_ OnnxruntimeEngineFactory* engine_ // Create inputs std::vector inputs; RETURN_IF_FAILED(CreateFeatureDescriptors(engine_factory, &input_helpers, ort_model, inputs)); - input_features_ = converter.ConvertToLearningModelDescriptors(inputs); + input_features_ = converter.ConvertToLearningModelDescriptors(inputs.data(), inputs.size()); // Create outputs static const winml_adapter_api_model_feature_helper output_helpers = { @@ -92,7 +92,7 @@ HRESULT ModelInfo::RuntimeClassInitialize(_In_ OnnxruntimeEngineFactory* engine_ std::vector outputs; RETURN_IF_FAILED(CreateFeatureDescriptors(engine_factory, &output_helpers, ort_model, outputs)); - output_features_ = converter.ConvertToLearningModelDescriptors(outputs); + output_features_ = converter.ConvertToLearningModelDescriptors(outputs.data(), outputs.size()); const char* out; size_t len; @@ -115,7 +115,6 @@ HRESULT ModelInfo::RuntimeClassInitialize(_In_ OnnxruntimeEngineFactory* engine_ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetVersion(ort_model, &version_), engine_factory->UseOrtApi()); - return S_OK; } @@ -235,7 +234,173 @@ STDMETHODIMP OnnruntimeModel::CloneModel(IModel** copy) { return S_OK; } +STDMETHODIMP OnnruntimeModel::SaveModel(_In_ const wchar_t* const file_name, _In_ unsigned size) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SaveModel(ort_model_.get(), file_name, size), + engine_factory_->UseOrtApi()); + return S_OK; +} + STDMETHODIMP OnnruntimeModel::DetachOrtModel(OrtModel** model) { *model = ort_model_.release(); return S_OK; } + +HRESULT GetValue(const char* key, const char* const* keys, const char* const* values, + size_t num_values_in_dictionary, const char** value) { + auto found_it = + std::find_if(keys, keys + num_values_in_dictionary, [key](auto& key_name) { + return _stricmp(key, key_name) == 0; + }); + if (found_it == (keys + num_values_in_dictionary)) { + return S_FALSE; + } + *value = values[std::distance(keys, found_it)]; + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::AddOperator( + _In_ const char* const op_type, _In_ const char* const op_name, _In_ const char* const op_domain, + _In_ const char* const* op_input_names, _In_ const char* const* actual_input_names, size_t num_inputs, + _In_ const char* const* op_output_names, _In_ const char* const* actual_output_names, size_t num_outputs, + _In_ const char* const* op_attribute_names, _In_ IValue** attribute_values, size_t num_attributes) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + auto ort_api = engine_factory_->UseOrtApi(); + + int32_t onnx_opset_version; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetOpsetVersion(ort_model_.get(), op_domain, &onnx_opset_version), + ort_api); + size_t input_count; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OperatorGetNumInputs(op_type, onnx_opset_version, op_domain, &input_count), + ort_api); + + size_t output_count; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OperatorGetNumOutputs(op_type, onnx_opset_version, op_domain, &output_count), + ort_api); + + std::vector input_names(input_count); + for (size_t i = 0; i < input_count; i++) { + const char* name; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OperatorGetInputName(op_type, onnx_opset_version, op_domain, i, &name), + ort_api); + + const char* actual_name; + if (S_OK == GetValue(name, op_input_names, actual_input_names, num_inputs, &actual_name)) + { + input_names[i] = actual_name; + } + } + + std::vector output_names(output_count); + for (size_t i = 0; i < output_count; i++) { + const char* name; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OperatorGetOutputName(op_type, onnx_opset_version, op_domain, i, &name), + ort_api); + const char* actual_name = nullptr; + if (S_OK == GetValue(name, op_output_names, actual_output_names, num_outputs, &actual_name)) { + output_names[i] = actual_name; + } + } + + std::vector attributes; + for (size_t i = 0; i < num_attributes; i++) { + attributes.push_back(static_cast(*(attribute_values + i))->UseOrtValue()); + } + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelAddOperator( + ort_model_.get(), op_type, op_name, onnx_opset_version, op_domain, input_names.data(), input_count, output_names.data(), output_count, op_attribute_names, attributes.data(), num_attributes), + engine_factory_->UseOrtApi()); + return S_OK; +} + +static ONNXTensorElementDataType +ONNXTensorElementDataTypeFromTensorKind(winml::TensorKind kind) { + switch (kind) { + case winml::TensorKind::Boolean: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; + } + case winml::TensorKind::String: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; + } + case winml::TensorKind::Float16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + } + case winml::TensorKind::Float: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + case winml::TensorKind::Double: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + } + case winml::TensorKind::Int8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + } + case winml::TensorKind::Int16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; + } + case winml::TensorKind::Int32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + } + case winml::TensorKind::Int64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } + case winml::TensorKind::UInt8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + } + case winml::TensorKind::UInt16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; + } + case winml::TensorKind::UInt32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; + } + case winml::TensorKind::UInt64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; + } + case winml::TensorKind::Complex64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; + } + case winml::TensorKind::Complex128: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; + } + default: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + } + } +} + +STDMETHODIMP OnnruntimeModel::AddModelInput(_In_ const char* const name, _In_ IDescriptorInfoProvider* descriptor_provider, bool is_constant, IValue* constant_value) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + auto ort_api = engine_factory_->UseOrtApi(); + + winrt::com_ptr<_winml::IDescriptorInfo> descriptor_info; + descriptor_provider->GetDescriptorInfo(engine_factory_.Get(), descriptor_info.put()); + + auto ort_type_info_provider = descriptor_info.as<_winml::IOrtTypeInfoProvider>(); + OrtTypeInfo* type_info; + ort_type_info_provider->GetTypeInfo(&type_info); + + if (is_constant) { + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelAddConstantInput(ort_model_.get(), name, type_info, static_cast(constant_value)->UseOrtValue()), + ort_api); + } else { + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelAddInput(ort_model_.get(), name, type_info), + ort_api); + } + + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::AddModelOutput(_In_ const char* const name, _In_ IDescriptorInfoProvider* descriptor_provider) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + auto ort_api = engine_factory_->UseOrtApi(); + + winrt::com_ptr<_winml::IDescriptorInfo> descriptor_info; + descriptor_provider->GetDescriptorInfo(engine_factory_.Get(), descriptor_info.put()); + + auto ort_type_info_provider = descriptor_info.as<_winml::IOrtTypeInfoProvider>(); + OrtTypeInfo* type_info; + ort_type_info_provider->GetTypeInfo(&type_info); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelAddOutput(ort_model_.get(), name, type_info), ort_api); + return S_OK; +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.h b/winml/lib/Api.Ort/OnnxruntimeModel.h index 1c34651bb06e3..ebca5e5d1ec5f 100644 --- a/winml/lib/Api.Ort/OnnxruntimeModel.h +++ b/winml/lib/Api.Ort/OnnxruntimeModel.h @@ -44,7 +44,7 @@ class ModelInfo : public Microsoft::WRL::RuntimeClass< std::string name_; std::string domain_; std::string description_; - int64_t version_; + int64_t version_ = 0; std::unordered_map model_metadata_; wfc::IVector input_features_; wfc::IVector output_features_; @@ -65,9 +65,24 @@ class OnnruntimeModel : public Microsoft::WRL::RuntimeClass< (); STDMETHOD(CloneModel) (IModel** copy); + STDMETHOD(SaveModel) + (_In_ const wchar_t* const file_name, + _In_ unsigned size); STDMETHOD(DetachOrtModel) (OrtModel** model); + STDMETHOD(AddOperator) + (_In_ const char* const op_type, _In_ const char* const op_name, _In_ const char* const op_domain, + _In_ const char* const* op_input_names, _In_ const char* const* actual_input_names, size_t num_inputs, + _In_ const char* const* op_output_names, _In_ const char* const* actual_output_names, size_t num_outputs, + _In_ const char* const* op_attribute_names, _In_ IValue** constant_value, size_t num_attributes); + + STDMETHOD(AddModelInput) + (_In_ const char* const name, _In_ IDescriptorInfoProvider* descriptor_provider, bool is_constant, IValue* default_value); + + STDMETHOD(AddModelOutput) + (_In_ const char* const name, _In_ IDescriptorInfoProvider* descriptor_provider); + private: UniqueOrtModel ort_model_; diff --git a/winml/lib/Api/ImageFeatureDescriptor.cpp b/winml/lib/Api/ImageFeatureDescriptor.cpp index 6d9408f5074ca..8dc5a0d6695f5 100644 --- a/winml/lib/Api/ImageFeatureDescriptor.cpp +++ b/winml/lib/Api/ImageFeatureDescriptor.cpp @@ -120,4 +120,14 @@ ImageColorSpaceGamma ImageFeatureDescriptor::GetColorSpaceGamma() { return color_space_gamma_; } + +HRESULT +ImageFeatureDescriptor::GetDescriptorInfo( + _winml::IEngineFactory* engine_factory, + _winml::IDescriptorInfo** info) { + // TODO: Need to add denotations here + engine_factory->CreateTensorDescriptorInfo(tensor_kind_, shape_.data(), shape_.size(), info); + return S_OK; +} + } // namespace WINMLP diff --git a/winml/lib/Api/ImageFeatureDescriptor.h b/winml/lib/Api/ImageFeatureDescriptor.h index 97cbdb2a1adb9..744702149cbd1 100644 --- a/winml/lib/Api/ImageFeatureDescriptor.h +++ b/winml/lib/Api/ImageFeatureDescriptor.h @@ -4,6 +4,7 @@ #pragma once #include "ImageFeatureDescriptor.g.h" +#include "iengine.h" namespace WINMLP { @@ -14,7 +15,8 @@ enum class ImageColorSpaceGamma { struct ImageFeatureDescriptor : ImageFeatureDescriptorT< ImageFeatureDescriptor, - ILearningModelFeatureDescriptorNative> { + ILearningModelFeatureDescriptorNative, + _winml::IDescriptorInfoProvider> { ImageFeatureDescriptor() = delete; ImageFeatureDescriptor( const char* name, @@ -75,6 +77,11 @@ struct ImageFeatureDescriptor : ImageFeatureDescriptorT< const wchar_t** description, uint32_t* cchDescription) override; + STDMETHOD(GetDescriptorInfo) + ( + _winml::IEngineFactory* engine_factory, + _winml::IDescriptorInfo** info) override; + private: winrt::hstring name_; winrt::hstring description_; diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index 68432cec6bf51..d7fd195b2a226 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -22,6 +22,17 @@ LearningModel::LearningModel( } WINML_CATCH_ALL +LearningModel::LearningModel( + _winml::IEngineFactory* engine_factory, + _winml::IModel* model, + const winml::ILearningModelOperatorProvider operator_provider) try : + operator_provider_(operator_provider) { + engine_factory_.copy_from(engine_factory); + model_.copy_from(model); + WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); +} +WINML_CATCH_ALL + LearningModel::LearningModel( const std::string& path, const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { diff --git a/winml/lib/Api/LearningModel.h b/winml/lib/Api/LearningModel.h index a2dac15122bdd..d7d696480caf9 100644 --- a/winml/lib/Api/LearningModel.h +++ b/winml/lib/Api/LearningModel.h @@ -29,6 +29,11 @@ struct LearningModel : LearningModelT { const std::string& path, const winml::ILearningModelOperatorProvider operator_provider); + LearningModel( + _winml::IEngineFactory* engine_factory, + _winml::IModel* model, + const winml::ILearningModelOperatorProvider operator_provider); + /* LearningModel properties (MachineLearningContract 1). */ hstring Author(); diff --git a/winml/lib/Api/LearningModelSession.cpp b/winml/lib/Api/LearningModelSession.cpp index 949e5683f7eab..7f69fa74021c8 100644 --- a/winml/lib/Api/LearningModelSession.cpp +++ b/winml/lib/Api/LearningModelSession.cpp @@ -27,6 +27,15 @@ static const GUID WINML_PIX_EVAL_CAPTURABLE_WORK_GUID = __uuidof(guid_details::W namespace WINMLP { +LearningModelSession::LearningModelSession(_winml::IEngine* engine) : model_(nullptr), + device_(LearningModelDeviceKind::Cpu), + session_options_(nullptr), + operator_registry_(nullptr, nullptr) +{ + engine_.copy_from(engine); +} + + LearningModelSession::LearningModelSession( winml::LearningModel const& model) try : LearningModelSession(model, make(LearningModelDeviceKind::Default)) {} @@ -432,4 +441,9 @@ STDMETHODIMP LearningModelSession::GetIntraOpNumThreads(uint32_t* numThreads) { return engine_->GetNumberOfIntraOpThreads(numThreads); } + +winml::LearningModelSession LearningModelSession::CreateInertSession(_winml::IEngine* engine) { + return winrt::make(engine); +} + } // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/LearningModelSession.h b/winml/lib/Api/LearningModelSession.h index 31d77d705514a..5949d713bd6a3 100644 --- a/winml/lib/Api/LearningModelSession.h +++ b/winml/lib/Api/LearningModelSession.h @@ -15,7 +15,7 @@ namespace WINMLP { struct LearningModelSession : LearningModelSessionT { /* LearningModelSession constructors (MachineLearningContract 1). */ - LearningModelSession() = delete; + LearningModelSession(_winml::IEngine* engine); LearningModelSession( winml::LearningModel const& model); @@ -83,6 +83,9 @@ struct LearningModelSession : LearningModelSessionT(description_.size()); return S_OK; } -} // namespace WINMLP + +HRESULT +MapFeatureDescriptor::GetDescriptorInfo( + _winml::IEngineFactory* engine_factory, + _winml::IDescriptorInfo** info) { + engine_factory->CreateMapDescriptorInfo(info); + return S_OK; +} + +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/MapFeatureDescriptor.h b/winml/lib/Api/MapFeatureDescriptor.h index e22d70f9040c8..c214eca55740f 100644 --- a/winml/lib/Api/MapFeatureDescriptor.h +++ b/winml/lib/Api/MapFeatureDescriptor.h @@ -4,11 +4,13 @@ #pragma once #include "MapFeatureDescriptor.g.h" +#include "iengine.h" namespace WINMLP { struct MapFeatureDescriptor : MapFeatureDescriptorT< MapFeatureDescriptor, - ILearningModelFeatureDescriptorNative> { + ILearningModelFeatureDescriptorNative, + _winml::IDescriptorInfoProvider> { MapFeatureDescriptor() = delete; MapFeatureDescriptor( @@ -47,6 +49,11 @@ struct MapFeatureDescriptor : MapFeatureDescriptorT< ( const wchar_t** description, uint32_t* cchDescription) override; + + STDMETHOD(GetDescriptorInfo) + ( + _winml::IEngineFactory* engine_factory, + _winml::IDescriptorInfo** info) override; private: winrt::hstring name_; diff --git a/winml/lib/Api/SequenceFeatureDescriptor.cpp b/winml/lib/Api/SequenceFeatureDescriptor.cpp index 30be69502d61c..8b2050859ee7b 100644 --- a/winml/lib/Api/SequenceFeatureDescriptor.cpp +++ b/winml/lib/Api/SequenceFeatureDescriptor.cpp @@ -62,4 +62,14 @@ SequenceFeatureDescriptor::GetDescription( *cchDescription = static_cast(description_.size()); return S_OK; } -} // namespace WINMLP + +HRESULT +SequenceFeatureDescriptor::GetDescriptorInfo( + _winml::IEngineFactory* engine_factory, + _winml::IDescriptorInfo** info) { + engine_factory->CreateSequenceDescriptorInfo(info); + return S_OK; +}; + + +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/SequenceFeatureDescriptor.h b/winml/lib/Api/SequenceFeatureDescriptor.h index 4627b8de0e118..cc4ad8c52f1de 100644 --- a/winml/lib/Api/SequenceFeatureDescriptor.h +++ b/winml/lib/Api/SequenceFeatureDescriptor.h @@ -4,11 +4,13 @@ #pragma once #include "SequenceFeatureDescriptor.g.h" +#include "iengine.h" namespace WINMLP { struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT< SequenceFeatureDescriptor, - ILearningModelFeatureDescriptorNative> { + ILearningModelFeatureDescriptorNative, + _winml::IDescriptorInfoProvider> { SequenceFeatureDescriptor() = delete; SequenceFeatureDescriptor( const char* name, @@ -42,6 +44,10 @@ struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT< const wchar_t** description, uint32_t* cchDescription) override; + STDMETHOD(GetDescriptorInfo) + ( + _winml::IEngineFactory* engine_factory, + _winml::IDescriptorInfo** info) override; private: winrt::hstring name_; winrt::hstring description_; diff --git a/winml/lib/Api/TensorFeatureDescriptor.cpp b/winml/lib/Api/TensorFeatureDescriptor.cpp index ee91887f25280..dd9a7f6c9a4d5 100644 --- a/winml/lib/Api/TensorFeatureDescriptor.cpp +++ b/winml/lib/Api/TensorFeatureDescriptor.cpp @@ -22,6 +22,18 @@ TensorFeatureDescriptor::TensorFeatureDescriptor( has_unsupported_image_metadata_(has_unsupported_image_metadata) { } +TensorFeatureDescriptor::TensorFeatureDescriptor( + hstring const& name, + hstring const& description, + winml::TensorKind const& kind, + array_view shape) : name_(name), + description_(description), + tensor_kind_(kind), + shape_(shape.begin(), shape.end()), + is_required_(true), + has_unsupported_image_metadata_(false) { +} + winml::TensorKind TensorFeatureDescriptor::TensorKind() try { return tensor_kind_; @@ -83,4 +95,12 @@ TensorFeatureDescriptor::GetDescription( *cchDescription = static_cast(description_.size()); return S_OK; } -} // namespace WINMLP + +HRESULT TensorFeatureDescriptor::GetDescriptorInfo( + _winml::IEngineFactory* engine_factory, + _winml::IDescriptorInfo** info){ + engine_factory->CreateTensorDescriptorInfo(tensor_kind_, shape_.data(), shape_.size(), info); + return S_OK; +}; + +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/TensorFeatureDescriptor.h b/winml/lib/Api/TensorFeatureDescriptor.h index eb83d60bc7466..4357ffe3dbc6a 100644 --- a/winml/lib/Api/TensorFeatureDescriptor.h +++ b/winml/lib/Api/TensorFeatureDescriptor.h @@ -4,11 +4,13 @@ #pragma once #include "TensorFeatureDescriptor.g.h" +#include "iengine.h" namespace WINMLP { struct TensorFeatureDescriptor : TensorFeatureDescriptorT< TensorFeatureDescriptor, - ILearningModelFeatureDescriptorNative> { + ILearningModelFeatureDescriptorNative, + _winml::IDescriptorInfoProvider> { TensorFeatureDescriptor() = delete; TensorFeatureDescriptor( const char* name, @@ -17,6 +19,12 @@ struct TensorFeatureDescriptor : TensorFeatureDescriptorT< const std::vector& shape, bool is_required, bool has_unsuppored_image_metadata); + + TensorFeatureDescriptor( + hstring const& name, + hstring const& description, + winml::TensorKind const& kind, + array_view shape); // ITensorDescriptor winml::TensorKind @@ -51,6 +59,12 @@ struct TensorFeatureDescriptor : TensorFeatureDescriptorT< const wchar_t** description, uint32_t* cchDescription) override; + + STDMETHOD(GetDescriptorInfo) + ( + _winml::IEngineFactory* engine_factory, + _winml::IDescriptorInfo** info) override; + private: winrt::hstring name_; winrt::hstring description_; @@ -59,4 +73,4 @@ struct TensorFeatureDescriptor : TensorFeatureDescriptorT< bool is_required_; bool has_unsupported_image_metadata_; }; -} // namespace WINMLP \ No newline at end of file +} // WINMLP \ No newline at end of file diff --git a/winml/lib/Api/impl/FeatureCompatibility.h b/winml/lib/Api/impl/FeatureCompatibility.h index 2966f854b2b78..1e8f818669f91 100644 --- a/winml/lib/Api/impl/FeatureCompatibility.h +++ b/winml/lib/Api/impl/FeatureCompatibility.h @@ -15,7 +15,7 @@ namespace _winml { namespace error_strings { // This must be kept in sync with the TensorKind enum in Windows.AI.MachineLearning.idl -const char* SzTensorKind[] = +__declspec(selectany) const char* SzTensorKind[] = { "Undefined", "Float", diff --git a/winml/lib/Api/impl/IData.h b/winml/lib/Api/impl/IData.h index d649f4d5e1cd1..0615e5fe50308 100644 --- a/winml/lib/Api/impl/IData.h +++ b/winml/lib/Api/impl/IData.h @@ -5,9 +5,6 @@ #include "IEngine.h" -// ILotusValueProviderPrivate exposes a private Lotus interface to the engine so that it can retrieve tensor -// resources stored in winrt structures. - namespace _winml { struct idata { diff --git a/winml/lib/Api/inc/ILotusValueProviderPrivate.h b/winml/lib/Api/inc/ILotusValueProviderPrivate.h index 0a487be263719..9bd855e306911 100644 --- a/winml/lib/Api/inc/ILotusValueProviderPrivate.h +++ b/winml/lib/Api/inc/ILotusValueProviderPrivate.h @@ -29,5 +29,4 @@ struct __declspec(uuid("27e2f437-0112-4693-849e-e04323a620fb")) __declspec(novta virtual HRESULT __stdcall UpdateSourceResourceData(BindingContext& binding_context, _winml::IValue* value) = 0; virtual HRESULT __stdcall AbiRepresentation(wf::IInspectable& abi_representation) = 0; }; - } // namespace _winml \ No newline at end of file diff --git a/winml/lib/Common/inc/iengine.h b/winml/lib/Common/inc/iengine.h index 1ff4615a7dfb5..c5cfeb41bd4e1 100644 --- a/winml/lib/Common/inc/iengine.h +++ b/winml/lib/Common/inc/iengine.h @@ -3,8 +3,61 @@ #pragma once +struct OrtTypeInfo; + namespace _winml { +interface IEngineFactory; + +using Resource = std::unique_ptr>; +MIDL_INTERFACE("31f39226-cfe8-4758-af38-3d01b2a33ee1") +IValue : IUnknown { + STDMETHOD(IsEmpty) + (bool* out) PURE; + + STDMETHOD(IsCpu) + (bool* out) PURE; + + STDMETHOD(GetResource) + (_winml::Resource & resource) PURE; + + STDMETHOD(IsTensor) + (bool* out) PURE; + + STDMETHOD(IsOfTensorType) + (winml::TensorKind kind, bool* out) PURE; + + STDMETHOD(GetTensorShape) + (std::vector & shape_vector) PURE; + + STDMETHOD(IsOfMapType) + (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) PURE; + + STDMETHOD(IsOfVectorMapType) + (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) PURE; + + STDMETHOD(IsOfVectorTensorType) + (winml::TensorKind kind, bool* out) PURE; +}; + +MIDL_INTERFACE("4637dfcb-fc19-45c3-a632-c84942d0cf8e") +IOrtTypeInfoProvider : IUnknown { + STDMETHOD(GetTypeInfo) + (OrtTypeInfo * *info) PURE; +}; + +MIDL_INTERFACE("fe94665f-76cb-42a2-ab21-a06ae1c7f1ae") +IDescriptorInfo : IUnknown{ + +}; + +MIDL_INTERFACE("e3feaec4-eb09-4b82-973c-781f1c230842") +IDescriptorInfoProvider : IUnknown{ + STDMETHOD(GetDescriptorInfo) + (IEngineFactory* engine_factory, IDescriptorInfo * *info) PURE; +}; + + MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown { STDMETHOD(GetAuthor) @@ -16,7 +69,6 @@ IModelInfo : IUnknown { STDMETHOD(GetDomain) (const char** out, size_t* len) PURE; - STDMETHOD(GetDescription) (const char** out, size_t* len) PURE; @@ -43,43 +95,28 @@ IModel : IUnknown { STDMETHOD(CloneModel) (IModel **copy) PURE; -}; - -using Resource = std::unique_ptr>; -MIDL_INTERFACE("31f39226-cfe8-4758-af38-3d01b2a33ee1") -IValue : IUnknown { - STDMETHOD(IsEmpty) - (bool* out) PURE; - - STDMETHOD(IsCpu) - (bool* out) PURE; - - STDMETHOD(GetResource) - (_winml::Resource & resource) PURE; - STDMETHOD(IsTensor) - (bool* out) PURE; - - STDMETHOD(IsOfTensorType) - (winml::TensorKind kind, bool* out) PURE; + STDMETHOD(SaveModel) + (_In_ const wchar_t* const file_name, + _In_ unsigned size) PURE; - STDMETHOD(GetTensorShape) - (std::vector & shape_vector) PURE; - - STDMETHOD(IsOfMapType) - (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) PURE; + STDMETHOD(AddOperator) + (_In_ const char* const op_type, _In_ const char* const op_name, _In_ const char* const op_domain, + _In_ const char* const* op_input_names, _In_ const char* const* actual_input_names, size_t num_inputs, + _In_ const char* const* op_output_names, _In_ const char* const* actual_output_names, size_t num_outputs, + _In_ const char* const* op_attribute_names, _In_ IValue** constant_value, size_t num_attributes) PURE; - STDMETHOD(IsOfVectorMapType) - (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) PURE; + STDMETHOD(AddModelInput) + (_In_ const char* const name, _In_ IDescriptorInfoProvider* descriptor_provider, bool is_constant, IValue* default_value) PURE; - STDMETHOD(IsOfVectorTensorType) - (winml::TensorKind kind, bool* out) PURE; + STDMETHOD(AddModelOutput) + (_In_ const char* const name, _In_ IDescriptorInfoProvider* descriptor_provider) PURE; }; MIDL_INTERFACE("30c99886-38d2-41cb-a615-203fe7d7daac") IEngine : IUnknown { STDMETHOD(LoadModel) - (_In_ IModel*) PURE; + (_In_ IModel*)PURE; STDMETHOD(Initialize) () PURE; @@ -189,6 +226,9 @@ IEngineFactory : IUnknown { STDMETHOD(CreateModel) (_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) PURE; + STDMETHOD(CreateEmptyModel) + (_In_ int64_t opset, _Outptr_ IModel * *out) PURE; + STDMETHOD(CreateEngineBuilder) (_Outptr_ IEngineBuilder **engine_builder) PURE; @@ -197,6 +237,19 @@ IEngineFactory : IUnknown { STDMETHOD(CreateCustomRegistry) (_Out_ IMLOperatorRegistry **registry) PURE; + + STDMETHOD(CreateTensorDescriptorInfo) + ( + winml::TensorKind kind, + int64_t* dims, + size_t num_dims, + _Out_ IDescriptorInfo **info) PURE; + + STDMETHOD(CreateSequenceDescriptorInfo) + (_Out_ IDescriptorInfo **info) PURE; + + STDMETHOD(CreateMapDescriptorInfo) + (_Out_ IDescriptorInfo **info) PURE; }; } // namespace _winml diff --git a/winml/test/api/LearningModelSessionAPITest.cpp b/winml/test/api/LearningModelSessionAPITest.cpp index 859900d0d2acc..78a374e49475c 100644 --- a/winml/test/api/LearningModelSessionAPITest.cpp +++ b/winml/test/api/LearningModelSessionAPITest.cpp @@ -17,8 +17,20 @@ using namespace winrt; using namespace winml; using namespace wfc; +#ifndef BUILD_INBOX +// experimental +using namespace winml_experimental; +using Operator = winml_experimental::LearningModelOperator; + +static const wchar_t MS_EXPERIMENTAL_DOMAIN[] = L"com.microsoft.experimental"; +#endif + using wf::IPropertyValue; +#define INT64(x) static_cast(x) +#define SIZET(x) static_cast(x) +#define INT32(x) static_cast(x) + static void LearningModelSessionAPITestsClassSetup() { init_apartment(); #ifdef BUILD_INBOX @@ -32,61 +44,57 @@ static void CreateSessionDeviceDefault() LearningModelDevice learningModelDevice = nullptr; WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); - WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::Default)); - WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); + WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::Default)); + WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); } -static void CreateSessionDeviceCpu() -{ - LearningModel learningModel = nullptr; - LearningModelDevice learningModelDevice = nullptr; - WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); +static void CreateSessionDeviceCpu() { + LearningModel learningModel = nullptr; + LearningModelDevice learningModelDevice = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); - WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::Cpu)); - WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); - // for the CPU device, make sure that we get back NULL and 0 for any device properties - WINML_EXPECT_EQUAL(learningModelDevice.Direct3D11Device(), nullptr); - LARGE_INTEGER id; - id.QuadPart = APITest::GetAdapterIdQuadPart(learningModelDevice); - WINML_EXPECT_EQUAL(id.LowPart, static_cast(0)); - WINML_EXPECT_EQUAL(id.HighPart, 0); + WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::Cpu)); + WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); + // for the CPU device, make sure that we get back NULL and 0 for any device properties + WINML_EXPECT_EQUAL(learningModelDevice.Direct3D11Device(), nullptr); + LARGE_INTEGER id; + id.QuadPart = APITest::GetAdapterIdQuadPart(learningModelDevice); + WINML_EXPECT_EQUAL(id.LowPart, static_cast(0)); + WINML_EXPECT_EQUAL(id.HighPart, 0); } static void CreateSessionWithModelLoadedFromStream() { - LearningModel learningModel = nullptr; - LearningModelDevice learningModelDevice = nullptr; - std::wstring path = FileHelpers::GetModulePath() + L"model.onnx"; - auto storageFile = ws::StorageFile::GetFileFromPathAsync(path).get(); + LearningModel learningModel = nullptr; + LearningModelDevice learningModelDevice = nullptr; + std::wstring path = FileHelpers::GetModulePath() + L"model.onnx"; + auto storageFile = ws::StorageFile::GetFileFromPathAsync(path).get(); - WINML_EXPECT_NO_THROW(learningModel = LearningModel::LoadFromStream(storageFile)); + WINML_EXPECT_NO_THROW(learningModel = LearningModel::LoadFromStream(storageFile)); - WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::Default)); - WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); + WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::Default)); + WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); } -static void CreateSessionDeviceDirectX() -{ - LearningModel learningModel = nullptr; - LearningModelDevice learningModelDevice = nullptr; - WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); +static void CreateSessionDeviceDirectX() { + LearningModel learningModel = nullptr; + LearningModelDevice learningModelDevice = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); - WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectX)); - WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); + WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectX)); + WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); } -static void CreateSessionDeviceDirectXHighPerformance() -{ - LearningModel learningModel = nullptr; - LearningModelDevice learningModelDevice = nullptr; - WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); +static void CreateSessionDeviceDirectXHighPerformance() { + LearningModel learningModel = nullptr; + LearningModelDevice learningModelDevice = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); - WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectXHighPerformance)); - WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); + WINML_EXPECT_NO_THROW(learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectXHighPerformance)); + WINML_EXPECT_NO_THROW(LearningModelSession(learningModel, learningModelDevice)); } -static void CreateSessionDeviceDirectXMinimumPower() -{ +static void CreateSessionDeviceDirectXMinimumPower() { LearningModel learningModel = nullptr; LearningModelDevice learningModelDevice = nullptr; WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); @@ -96,162 +104,153 @@ static void CreateSessionDeviceDirectXMinimumPower() } static void AdapterIdAndDevice() { - LearningModel learningModel = nullptr; - LearningModelDevice learningModelDevice = nullptr; - LearningModelSession learningModelSession = nullptr; - WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); + LearningModel learningModel = nullptr; + LearningModelDevice learningModelDevice = nullptr; + LearningModelSession learningModelSession = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); - com_ptr factory; - WINML_EXPECT_HRESULT_SUCCEEDED(CreateDXGIFactory1(__uuidof(IDXGIFactory6), factory.put_void())); - com_ptr adapter; - - learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectX); - WINML_EXPECT_HRESULT_SUCCEEDED(factory->EnumAdapters(0, adapter.put())); - DXGI_ADAPTER_DESC desc; - WINML_EXPECT_HRESULT_SUCCEEDED(adapter->GetDesc(&desc)); - LARGE_INTEGER id; - id.QuadPart = APITest::GetAdapterIdQuadPart(learningModelDevice); - WINML_EXPECT_EQUAL(desc.AdapterLuid.LowPart, id.LowPart); - WINML_EXPECT_EQUAL(desc.AdapterLuid.HighPart, id.HighPart); - WINML_EXPECT_TRUE(learningModelDevice.Direct3D11Device() != nullptr); - - learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectXHighPerformance); - adapter = nullptr; - WINML_EXPECT_HRESULT_SUCCEEDED(factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, __uuidof(IDXGIAdapter), adapter.put_void())); - WINML_EXPECT_HRESULT_SUCCEEDED(adapter->GetDesc(&desc)); - id.QuadPart = APITest::GetAdapterIdQuadPart(learningModelDevice); - WINML_EXPECT_EQUAL(desc.AdapterLuid.LowPart, id.LowPart); - WINML_EXPECT_EQUAL(desc.AdapterLuid.HighPart, id.HighPart); - WINML_EXPECT_TRUE(learningModelDevice.Direct3D11Device() != nullptr); - - adapter = nullptr; - learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectXMinPower); - WINML_EXPECT_HRESULT_SUCCEEDED(factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_MINIMUM_POWER, __uuidof(IDXGIAdapter), adapter.put_void())); - WINML_EXPECT_HRESULT_SUCCEEDED(adapter->GetDesc(&desc)); - id.QuadPart = APITest::GetAdapterIdQuadPart(learningModelDevice); - WINML_EXPECT_EQUAL(desc.AdapterLuid.LowPart, id.LowPart); - WINML_EXPECT_EQUAL(desc.AdapterLuid.HighPart, id.HighPart); - WINML_EXPECT_TRUE(learningModelDevice.Direct3D11Device() != nullptr); - - WINML_EXPECT_NO_THROW(learningModelSession = LearningModelSession(learningModel, learningModelDevice)); - WINML_EXPECT_EQUAL(learningModelSession.Device().AdapterId(), learningModelDevice.AdapterId()); -} - -static void EvaluateFeatures() -{ - std::vector shape = { 4 }; - std::vector data = { L"one", L"two", L"three", L"four" }; + com_ptr factory; + WINML_EXPECT_HRESULT_SUCCEEDED(CreateDXGIFactory1(__uuidof(IDXGIFactory6), factory.put_void())); + com_ptr adapter; + + learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectX); + WINML_EXPECT_HRESULT_SUCCEEDED(factory->EnumAdapters(0, adapter.put())); + DXGI_ADAPTER_DESC desc; + WINML_EXPECT_HRESULT_SUCCEEDED(adapter->GetDesc(&desc)); + LARGE_INTEGER id; + id.QuadPart = APITest::GetAdapterIdQuadPart(learningModelDevice); + WINML_EXPECT_EQUAL(desc.AdapterLuid.LowPart, id.LowPart); + WINML_EXPECT_EQUAL(desc.AdapterLuid.HighPart, id.HighPart); + WINML_EXPECT_TRUE(learningModelDevice.Direct3D11Device() != nullptr); + + learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectXHighPerformance); + adapter = nullptr; + WINML_EXPECT_HRESULT_SUCCEEDED(factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, __uuidof(IDXGIAdapter), adapter.put_void())); + WINML_EXPECT_HRESULT_SUCCEEDED(adapter->GetDesc(&desc)); + id.QuadPart = APITest::GetAdapterIdQuadPart(learningModelDevice); + WINML_EXPECT_EQUAL(desc.AdapterLuid.LowPart, id.LowPart); + WINML_EXPECT_EQUAL(desc.AdapterLuid.HighPart, id.HighPart); + WINML_EXPECT_TRUE(learningModelDevice.Direct3D11Device() != nullptr); + + adapter = nullptr; + learningModelDevice = LearningModelDevice(LearningModelDeviceKind::DirectXMinPower); + WINML_EXPECT_HRESULT_SUCCEEDED(factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_MINIMUM_POWER, __uuidof(IDXGIAdapter), adapter.put_void())); + WINML_EXPECT_HRESULT_SUCCEEDED(adapter->GetDesc(&desc)); + id.QuadPart = APITest::GetAdapterIdQuadPart(learningModelDevice); + WINML_EXPECT_EQUAL(desc.AdapterLuid.LowPart, id.LowPart); + WINML_EXPECT_EQUAL(desc.AdapterLuid.HighPart, id.HighPart); + WINML_EXPECT_TRUE(learningModelDevice.Direct3D11Device() != nullptr); + + WINML_EXPECT_NO_THROW(learningModelSession = LearningModelSession(learningModel, learningModelDevice)); + WINML_EXPECT_EQUAL(learningModelSession.Device().AdapterId(), learningModelDevice.AdapterId()); +} - // create from buffer - auto tensor = TensorString::CreateFromArray(shape, data); - WINML_EXPECT_EQUAL(tensor.GetAsVectorView().Size(), data.size()); - WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(tensor.GetAsVectorView()))); +static void EvaluateFeatures() { + std::vector shape = {4}; + std::vector data = {L"one", L"two", L"three", L"four"}; - // create from vector view - auto dataCopy = data; - tensor = TensorString::CreateFromIterable( - shape, winrt::single_threaded_vector(std::move(dataCopy)).GetView()); - WINML_EXPECT_EQUAL(tensor.GetAsVectorView().Size(), data.size()); - WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(tensor.GetAsVectorView()))); + // create from buffer + auto tensor = TensorString::CreateFromArray(shape, data); + WINML_EXPECT_EQUAL(tensor.GetAsVectorView().Size(), data.size()); + WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(tensor.GetAsVectorView()))); - LearningModel learningModel = nullptr; - WINML_EXPECT_NO_THROW(APITest::LoadModel(L"id-tensor-string.onnx", learningModel)); - LearningModelSession session(learningModel); + // create from vector view + auto dataCopy = data; + tensor = TensorString::CreateFromIterable( + shape, winrt::single_threaded_vector(std::move(dataCopy)).GetView()); + WINML_EXPECT_EQUAL(tensor.GetAsVectorView().Size(), data.size()); + WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(tensor.GetAsVectorView()))); - auto outputTensor = TensorString::Create(); + LearningModel learningModel = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"id-tensor-string.onnx", learningModel)); + LearningModelSession session(learningModel); - std::map featuresstandardmap; - featuresstandardmap[L"X"] = tensor; - featuresstandardmap[L"Y"] = outputTensor; - auto featureswinrtmap = winrt::single_threaded_map(std::move(featuresstandardmap)); - session.EvaluateFeatures(featureswinrtmap, L"0"); + auto outputTensor = TensorString::Create(); - // verify identity model round-trip works - WINML_EXPECT_EQUAL(outputTensor.GetAsVectorView().Size(), data.size()); - WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(outputTensor.GetAsVectorView()))); + std::map featuresstandardmap; + featuresstandardmap[L"X"] = tensor; + featuresstandardmap[L"Y"] = outputTensor; + auto featureswinrtmap = winrt::single_threaded_map(std::move(featuresstandardmap)); + session.EvaluateFeatures(featureswinrtmap, L"0"); + + // verify identity model round-trip works + WINML_EXPECT_EQUAL(outputTensor.GetAsVectorView().Size(), data.size()); + WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(outputTensor.GetAsVectorView()))); } -static void EvaluateFeaturesAsync() -{ - std::vector shape = { 4 }; - std::vector data = { L"one", L"two", L"three", L"four" }; +static void EvaluateFeaturesAsync() { + std::vector shape = {4}; + std::vector data = {L"one", L"two", L"three", L"four"}; - // create from buffer - auto tensor = TensorString::CreateFromArray(shape, data); - WINML_EXPECT_EQUAL(tensor.GetAsVectorView().Size(), data.size()); - WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(tensor.GetAsVectorView()))); + // create from buffer + auto tensor = TensorString::CreateFromArray(shape, data); + WINML_EXPECT_EQUAL(tensor.GetAsVectorView().Size(), data.size()); + WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(tensor.GetAsVectorView()))); - // create from vector view - auto dataCopy = data; - tensor = TensorString::CreateFromIterable( - shape, winrt::single_threaded_vector(std::move(dataCopy)).GetView()); - WINML_EXPECT_EQUAL(tensor.GetAsVectorView().Size(), data.size()); - WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(tensor.GetAsVectorView()))); + // create from vector view + auto dataCopy = data; + tensor = TensorString::CreateFromIterable( + shape, winrt::single_threaded_vector(std::move(dataCopy)).GetView()); + WINML_EXPECT_EQUAL(tensor.GetAsVectorView().Size(), data.size()); + WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(tensor.GetAsVectorView()))); - LearningModel learningModel = nullptr; - WINML_EXPECT_NO_THROW(APITest::LoadModel(L"id-tensor-string.onnx", learningModel)); - LearningModelSession session(learningModel); + LearningModel learningModel = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"id-tensor-string.onnx", learningModel)); + LearningModelSession session(learningModel); - auto outputTensor = TensorString::Create(shape); + auto outputTensor = TensorString::Create(shape); - std::map featuresstandardmap; - featuresstandardmap[L"X"] = tensor; - featuresstandardmap[L"Y"] = outputTensor; - auto featureswinrtmap = winrt::single_threaded_map(std::move(featuresstandardmap)); - session.EvaluateFeaturesAsync(featureswinrtmap, L"0").get(); + std::map featuresstandardmap; + featuresstandardmap[L"X"] = tensor; + featuresstandardmap[L"Y"] = outputTensor; + auto featureswinrtmap = winrt::single_threaded_map(std::move(featuresstandardmap)); + session.EvaluateFeaturesAsync(featureswinrtmap, L"0").get(); - // verify identity model round-trip works - WINML_EXPECT_EQUAL(outputTensor.GetAsVectorView().Size(), data.size()); - WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(outputTensor.GetAsVectorView()))); + // verify identity model round-trip works + WINML_EXPECT_EQUAL(outputTensor.GetAsVectorView().Size(), data.size()); + WINML_EXPECT_TRUE(std::equal(data.cbegin(), data.cend(), begin(outputTensor.GetAsVectorView()))); } -static void EvaluationProperties() -{ - // load a model - LearningModel learningModel = nullptr; - WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); - // create a session - LearningModelSession learningModelSession = nullptr; - learningModelSession = LearningModelSession(learningModel); - // set a property - auto value = wf::PropertyValue::CreateBoolean(true); - learningModelSession.EvaluationProperties().Insert(L"propName1", value); - // get the property and make sure it's there with the right value - auto value2 = learningModelSession.EvaluationProperties().Lookup(L"propName1"); - WINML_EXPECT_EQUAL(value2.as().GetBoolean(), true); -} - -static LearningModelSession CreateSession(LearningModel model) -{ - LearningModelDevice device(nullptr); - WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX)); +static void EvaluationProperties() { + // load a model + LearningModel learningModel = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); + // create a session + LearningModelSession learningModelSession = nullptr; + learningModelSession = LearningModelSession(learningModel); + // set a property + auto value = winrt::Windows::Foundation::PropertyValue::CreateBoolean(true); + learningModelSession.EvaluationProperties().Insert(L"propName1", value); + // get the property and make sure it's there with the right value + auto value2 = learningModelSession.EvaluationProperties().Lookup(L"propName1"); + WINML_EXPECT_EQUAL(value2.as().GetBoolean(), true); +} - LearningModelSession session(nullptr); - if (CommonDeviceHelpers::IsFloat16Supported(device)) - { - WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device)); - } - else - { - WINML_EXPECT_THROW_SPECIFIC( - session = LearningModelSession(model, device), - winrt::hresult_error, - [](const winrt::hresult_error& e) -> bool - { - return e.code() == DXGI_ERROR_UNSUPPORTED; +static LearningModelSession CreateSession(LearningModel model) { + LearningModelDevice device(nullptr); + WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX)); + + LearningModelSession session(nullptr); + if (CommonDeviceHelpers::IsFloat16Supported(device)) { + WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device)); + } else { + WINML_EXPECT_THROW_SPECIFIC( + session = LearningModelSession(model, device), + winrt::hresult_error, + [](const winrt::hresult_error& e) -> bool { + return e.code() == DXGI_ERROR_UNSUPPORTED; }); - } + } - return session; + return session; } -static void CreateSessionWithCastToFloat16InModel() -{ - // load a model - LearningModel learningModel = nullptr; - WINML_EXPECT_NO_THROW(APITest::LoadModel(L"fp16-truncate-with-cast.onnx", learningModel)); +static void CreateSessionWithCastToFloat16InModel() { + // load a model + LearningModel learningModel = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"fp16-truncate-with-cast.onnx", learningModel)); - CreateSession(learningModel); + CreateSession(learningModel); } static void CreateSessionWithFloat16InitializersInModel() @@ -260,58 +259,52 @@ static void CreateSessionWithFloat16InitializersInModel() LearningModel learningModel = nullptr; WINML_EXPECT_NO_THROW(APITest::LoadModel(L"fp16-initializer.onnx", learningModel)); - CreateSession(learningModel); + CreateSession(learningModel); } static void EvaluateSessionAndCloseModelHelper( LearningModelDeviceKind kind, - bool close_model_on_session_creation) -{ - auto shape = std::vector{ 1, 1000 }; + bool close_model_on_session_creation) { + auto shape = std::vector{1, 1000}; - auto model = ProtobufHelpers::CreateModel(TensorKind::Float, shape, 1000); + auto model = ProtobufHelpers::CreateModel(TensorKind::Float, shape, 1000); - auto device = LearningModelDevice(kind); - auto options = LearningModelSessionOptions(); + auto device = LearningModelDevice(kind); + auto options = LearningModelSessionOptions(); - // close the model on session creation - options.CloseModelOnSessionCreation(close_model_on_session_creation); + // close the model on session creation + options.CloseModelOnSessionCreation(close_model_on_session_creation); - // ensure you can create a session from the model - LearningModelSession session(nullptr); + // ensure you can create a session from the model + LearningModelSession session(nullptr); - WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device, options)); + WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device, options)); - std::vector input(1000); - std::iota(std::begin(input), std::end(input), 0.0f); - auto tensor_input = TensorFloat::CreateFromShapeArrayAndDataArray(shape, input); - auto binding = LearningModelBinding(session); - binding.Bind(L"input", tensor_input); + std::vector input(1000); + std::iota(std::begin(input), std::end(input), 0.0f); + auto tensor_input = TensorFloat::CreateFromArray(shape, input); + auto binding = LearningModelBinding(session); + binding.Bind(L"input", tensor_input); - LearningModelEvaluationResult result(nullptr); - WINML_EXPECT_NO_THROW(result = session.Evaluate(binding, L"")); + LearningModelEvaluationResult result(nullptr); + WINML_EXPECT_NO_THROW(result = session.Evaluate(binding, L"")); - if (close_model_on_session_creation) - { - // ensure that the model has been closed - WINML_EXPECT_THROW_SPECIFIC( - LearningModelSession(model, device, options), - winrt::hresult_error, - [](const winrt::hresult_error& e) -> bool - { - return e.code() == E_INVALIDARG; + if (close_model_on_session_creation) { + // ensure that the model has been closed + WINML_EXPECT_THROW_SPECIFIC( + LearningModelSession(model, device, options), + winrt::hresult_error, + [](const winrt::hresult_error& e) -> bool { + return e.code() == E_INVALIDARG; }); - } - else - { - WINML_EXPECT_NO_THROW(LearningModelSession(model, device, options)); - } + } else { + WINML_EXPECT_NO_THROW(LearningModelSession(model, device, options)); + } } -static void EvaluateSessionAndCloseModel() -{ - WINML_EXPECT_NO_THROW(::EvaluateSessionAndCloseModelHelper(LearningModelDeviceKind::Cpu, true)); - WINML_EXPECT_NO_THROW(::EvaluateSessionAndCloseModelHelper(LearningModelDeviceKind::Cpu, false)); +static void EvaluateSessionAndCloseModel() { + WINML_EXPECT_NO_THROW(::EvaluateSessionAndCloseModelHelper(LearningModelDeviceKind::Cpu, true)); + WINML_EXPECT_NO_THROW(::EvaluateSessionAndCloseModelHelper(LearningModelDeviceKind::Cpu, false)); } static void NamedDimensionOverride() @@ -365,7 +358,7 @@ static void CloseSession() WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", learningModel)); LearningModelSession session = nullptr; - /* + /* HANDLE currentProcessHandle = NULL; try { @@ -380,9 +373,9 @@ static void CloseSession() SIZE_T afterSessionCloseWorkingSetSize = 0; bool getProcessMemoryInfoSuccess = false; */ - WINML_EXPECT_NO_THROW(session = LearningModelSession(learningModel)); + WINML_EXPECT_NO_THROW(session = LearningModelSession(learningModel)); - /* + /* // Get the current process memory info after session creation. getProcessMemoryInfoSuccess = GetProcessMemoryInfo(currentProcessHandle, &pmc, sizeof(pmc)); if (!getProcessMemoryInfoSuccess) @@ -392,9 +385,9 @@ static void CloseSession() beforeSessionCloseWorkingSetSize = pmc.WorkingSetSize; pmc = { 0 }; */ - WINML_EXPECT_NO_THROW(session.Close()); + WINML_EXPECT_NO_THROW(session.Close()); - /* + /* Bug 23659026: Working set difference tolerance is too tight for LearningModelSessionAPITests::CloseSession https://microsoft.visualstudio.com/OS/_workitems/edit/23659026 @@ -417,21 +410,562 @@ static void CloseSession() VERIFY_IS_LESS_THAN(expectedWorkingSetDifference - (beforeSessionCloseWorkingSetSize - afterSessionCloseWorkingSetSize), expectedWorkingSetDifference * tolerance); */ - // verify that model still has metadata info after session close - std::wstring author(learningModel.Author()); - WINML_EXPECT_EQUAL(author, L"onnx-caffe2"); + // verify that model still has metadata info after session close + std::wstring author(learningModel.Author()); + WINML_EXPECT_EQUAL(author, L"onnx-caffe2"); + + // verify that session throws RO_E_CLOSED error + std::vector input(1 * 3 * 224 * 224, 0); + std::vector shape = {1, 3, 224, 224}; + auto tensor_input = TensorFloat::CreateFromArray(shape, input); + WINML_EXPECT_THROW_SPECIFIC(LearningModelBinding binding(session), + winrt::hresult_error, + [](const winrt::hresult_error& e) -> bool { + return e.code() == RO_E_CLOSED; + }); +} - // verify that session throws RO_E_CLOSED error - std::vector input(1 * 3 * 224 * 224, 0); - std::vector shape = { 1, 3, 224, 224 }; - auto tensor_input = TensorFloat::CreateFromShapeArrayAndDataArray(shape, input); - WINML_EXPECT_THROW_SPECIFIC(LearningModelBinding binding(session), - winrt::hresult_error, - [](const winrt::hresult_error &e) -> bool - { - return e.code() == RO_E_CLOSED; - }); - } +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) +static void WindowFunction(const wchar_t* window_operator_name, TensorKind kind) { + std::vector scalar_shape = {}; + std::vector output_shape = {32}; + auto double_data_type = TensorInt64Bit::CreateFromArray({}, {11}); + + auto window_operator = + Operator(window_operator_name, MS_EXPERIMENTAL_DOMAIN) + .SetInput(L"size", L"Input") + .SetOutput(L"output", L"Output"); + + if (kind == TensorKind::Double) { + window_operator.SetAttribute(L"output_datatype", double_data_type); + } + + auto model = + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Int64, scalar_shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", kind, output_shape)) + .Operators().Add(window_operator) + .CreateModel(); + + LearningModelSession session(model); + LearningModelBinding binding(session); + + binding.Bind(L"Input", TensorInt64Bit::CreateFromArray(scalar_shape, {32})); + + // Evaluate + auto result = session.Evaluate(binding, L""); + + // Check results + printf("Output\n"); + if (kind == TensorKind::Float) { + auto y_tensor = result.Outputs().Lookup(L"Output").as(); + auto y_ivv = y_tensor.GetAsVectorView(); + for (int i = 0; i < output_shape[0]; i++) { + printf("%f, ", y_ivv.GetAt(i)); + } + } + if (kind == TensorKind::Double) { + auto y_tensor = result.Outputs().Lookup(L"Output").as(); + auto y_ivv = y_tensor.GetAsVectorView(); + for (int i = 0; i < output_shape[0]; i++) { + printf("%f, ", y_ivv.GetAt(i)); + } + } + printf("\n"); +} +#endif + +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) +static void DiscreteFourierTransform(bool is_onesided = false) { + std::vector shape = {1, 5}; + std::vector output_shape = {1, 5, 2}; + output_shape[1] = is_onesided ? (1 + (shape[1] >> 1)) : shape[1]; + + auto model = + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input.Signal", TensorKind::Float, shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output.Spectra", TensorKind::Float, output_shape)) + .Operators().Add(Operator(L"DFT", MS_EXPERIMENTAL_DOMAIN) + .SetInput(L"input", L"Input.Signal") + .SetAttribute(L"onesided", TensorInt64Bit::CreateFromArray({}, {is_onesided})) + .SetOutput(L"output", L"Output.Spectra")) + .CreateModel(); + + LearningModelSession session(model); + LearningModelBinding binding(session); + + // Populate binding + binding.Bind(L"Input.Signal", TensorFloat::CreateFromArray(shape, {1, 2, 3, 4, 5})); + + // Evaluate + auto result = session.Evaluate(binding, L""); + + // Check results + printf("Output.Spectra\n"); + auto y_tensor = result.Outputs().Lookup(L"Output.Spectra").as(); + auto y_ivv = y_tensor.GetAsVectorView(); + for (int i = 0; i < output_shape[0] * output_shape[1] * 2; i += 2) { + printf("(%f + %fi), ", y_ivv.GetAt(i), y_ivv.GetAt(i + 1)); + } + printf("\n"); +} +#endif + +template +static auto MakePureFrequency(float frequency_in_hertz, size_t signal_size, size_t sample_rate) { + float amplitude = 4; + float angular_velocity = frequency_in_hertz * 2 * 3.1415f; + std::vector signal(signal_size); + for (size_t i = 0; i < signal_size; i++) { + T time = i / static_cast(sample_rate); + signal[i] = amplitude * cos(angular_velocity * time); + } + return signal; +} + +template +static auto MakeMiddleC(size_t signal_size, size_t sample_rate) { + float middle_c_in_hertz = 261.626f; + return MakePureFrequency(middle_c_in_hertz, signal_size, sample_rate); +} + +template +static auto MakeC2(size_t signal_size, size_t sample_rate) { + float middle_c_in_hertz = 261.626f * 2; + return MakePureFrequency(middle_c_in_hertz, signal_size, sample_rate); +} + +template +static auto MakeC4(size_t signal_size, size_t sample_rate) { + float middle_c_in_hertz = 261.626f * 4; + return MakePureFrequency(middle_c_in_hertz, signal_size, sample_rate); +} + +template +static auto MakeThreeTones(size_t signal_size, size_t sample_rate) { + auto middle_c = MakeMiddleC(signal_size, sample_rate); + auto c2 = MakeC2(signal_size, sample_rate); + auto c4 = MakeC4(signal_size, sample_rate); + for (size_t i = 0; i < signal_size; i++) { + middle_c[i] = (i < signal_size / 3) ? + middle_c[i] : + (i < 2*signal_size/3) ? + (middle_c[i] + c2[i]) : + (middle_c[i] + c2[i] + c4[i]); + } + return middle_c; +} + +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) +static void STFT(size_t batch_size, size_t signal_size, size_t dft_size, + size_t hop_size, size_t sample_rate, bool is_onesided = false) { + auto n_dfts = static_cast(1 + floor((signal_size - dft_size) / hop_size)); + auto input_shape = std::vector{1, INT64(signal_size)}; + auto output_shape = + std::vector{ + INT64(batch_size), + INT64(n_dfts), + is_onesided ? ((INT64(dft_size) >> 1) + 1) : INT64(dft_size), + 2 + }; + auto dft_length = TensorInt64Bit::CreateFromArray({}, {INT64(dft_size)}); + + auto model = + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input.TimeSignal", TensorKind::Float, input_shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output.STFT", TensorKind::Float, output_shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output.HannWindow", TensorKind::Float, {INT64(dft_size)})) + .Operators().Add(Operator(L"HannWindow", MS_EXPERIMENTAL_DOMAIN) + .SetConstant(L"size", dft_length) + .SetOutput(L"output", L"Output.HannWindow")) + .Operators().Add(Operator(L"STFT", MS_EXPERIMENTAL_DOMAIN) + .SetAttribute(L"onesided", TensorInt64Bit::CreateFromArray({}, {INT64(is_onesided)})) + .SetInput(L"signal", L"Input.TimeSignal") + .SetInput(L"window", L"Output.HannWindow") + .SetConstant(L"frame_length", dft_length) + .SetConstant(L"frame_step", TensorInt64Bit::CreateFromArray({}, {INT64(hop_size)})) + .SetOutput(L"output", L"Output.STFT")) + .CreateModel(); + + LearningModelSession session(model); + LearningModelBinding binding(session); + + // Create signal binding + auto signal = MakeMiddleC(signal_size, sample_rate); + printf("\n"); + printf("Input.TimeSignal:\n"); + for (size_t i = 0; i < dft_size; i++) { + printf("%f, ", signal[i]); + } + + // Bind + binding.Bind(L"Input.TimeSignal", TensorFloat::CreateFromArray(input_shape, signal)); + + // Evaluate + auto result = session.Evaluate(binding, L""); + + printf("\n"); + printf("Output.HannWindow\n"); + auto window_tensor = result.Outputs().Lookup(L"Output.HannWindow").as(); + auto window_ivv = window_tensor.GetAsVectorView(); + for (uint32_t i = 0; i < window_ivv.Size(); i++) { + printf("%f, ", window_ivv.GetAt(i)); + } + printf("\n"); + printf("Output.STFT\n"); + // Check results + auto y_tensor = result.Outputs().Lookup(L"Output.STFT").as(); + auto y_ivv = y_tensor.GetAsVectorView(); + auto size = y_ivv.Size(); + WINML_EXPECT_EQUAL(size, n_dfts * output_shape[2] * 2); + for (size_t dft_idx = 0; dft_idx < n_dfts; dft_idx++) { + for (size_t i = 0; INT64(i) < output_shape[2]; i++) { + auto real_idx = static_cast((i * 2) + (2 * dft_idx * output_shape[2])); + printf("(%d, %f , %fi), ", static_cast(i), y_ivv.GetAt(real_idx), y_ivv.GetAt(real_idx + 1)); + } + } + + printf("\n"); +} +#endif + +static void ModelBuilding_MelWeightMatrix() { +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) + std::vector output_shape = {INT64(9), INT64(8)}; + auto builder = + LearningModelBuilder::Create(13) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output.MelWeightMatrix", TensorKind::Float, output_shape)) + .Operators().Add(Operator(L"MelWeightMatrix", MS_EXPERIMENTAL_DOMAIN) + .SetConstant(L"num_mel_bins", TensorInt64Bit::CreateFromArray({}, {INT64(8)})) + .SetConstant(L"dft_length", TensorInt64Bit::CreateFromArray({}, {INT64(16)})) + .SetConstant(L"sample_rate", TensorInt64Bit::CreateFromArray({}, {INT64(8192)})) + .SetConstant(L"lower_edge_hertz", TensorFloat::CreateFromArray({}, {0})) + .SetConstant(L"upper_edge_hertz", TensorFloat::CreateFromArray({}, {8192 / 2.f})) + .SetOutput(L"output", L"Output.MelWeightMatrix")); + auto model = builder.CreateModel(); + + LearningModelSession session(model); + LearningModelBinding binding(session); + + auto result = session.Evaluate(binding, L""); + + printf("\n"); + printf("Output.MelWeightMatrix\n"); + { + auto y_tensor = result.Outputs().Lookup(L"Output.MelWeightMatrix").as(); + auto y_ivv = y_tensor.GetAsVectorView(); + for (unsigned i = 0; i < y_ivv.Size(); i++) { + printf("%f, ", y_ivv.GetAt(i)); + } + } + + printf("\n"); +#endif +} + +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) +static void MelSpectrogramOnThreeToneSignal( + size_t batch_size, size_t signal_size, size_t window_size, size_t dft_size, + size_t hop_size, size_t n_mel_bins, size_t sampling_rate) { + auto n_dfts = static_cast(1 + floor((signal_size - dft_size) / hop_size)); + auto onesided_dft_size = (dft_size >> 1) + 1; + std::vector signal_shape = {INT64(batch_size), INT64(signal_size)}; + std::vector mel_spectrogram_shape = {INT64(batch_size), 1, INT64(n_dfts), INT64(n_mel_bins)}; + + auto builder = + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input.TimeSignal", TensorKind::Float, signal_shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output.MelSpectrogram", TensorKind::Float, mel_spectrogram_shape)) + .Operators().Add(Operator(L"HannWindow", MS_EXPERIMENTAL_DOMAIN) + .SetConstant(L"size", TensorInt64Bit::CreateFromArray({}, {INT64(window_size)})) + .SetOutput(L"output", L"hann_window")) + .Operators().Add(Operator(L"STFT", MS_EXPERIMENTAL_DOMAIN) + .SetName(L"STFT_NAMED_NODE") + .SetInput(L"signal", L"Input.TimeSignal") + .SetInput(L"window", L"hann_window") + .SetConstant(L"frame_length", TensorInt64Bit::CreateFromArray({}, {INT64(dft_size)})) + .SetConstant(L"frame_step", TensorInt64Bit::CreateFromArray({}, {INT64(hop_size)})) + .SetOutput(L"output", L"stft_output")) + .Operators().Add(Operator(L"ReduceSumSquare") + .SetInput(L"data", L"stft_output") + .SetAttribute(L"axes", TensorInt64Bit::CreateFromArray({1}, {3})) + .SetAttribute(L"keepdims", TensorInt64Bit::CreateFromArray({}, {0})) + .SetOutput(L"reduced", L"magnitude_squared")) + .Operators().Add(Operator(L"Div") + .SetInput(L"A", L"magnitude_squared") + .SetConstant(L"B", TensorFloat::CreateFromArray({}, {static_cast(dft_size)})) + .SetOutput(L"C", L"power_frames")) + .Operators().Add(Operator(L"MelWeightMatrix", MS_EXPERIMENTAL_DOMAIN) + .SetConstant(L"num_mel_bins", TensorInt64Bit::CreateFromArray({}, {INT64(n_mel_bins)})) + .SetConstant(L"dft_length", TensorInt64Bit::CreateFromArray({}, {INT64(dft_size)})) + .SetConstant(L"sample_rate", TensorInt64Bit::CreateFromArray({}, {INT64(sampling_rate)})) + .SetConstant(L"lower_edge_hertz", TensorFloat::CreateFromArray({}, {0})) + .SetConstant(L"upper_edge_hertz", TensorFloat::CreateFromArray({}, {sampling_rate / 2.f})) + .SetOutput(L"output", L"mel_weight_matrix")) + .Operators().Add(Operator(L"Reshape") + .SetInput(L"data", L"power_frames") + .SetConstant(L"shape", TensorInt64Bit::CreateFromArray({2}, {INT64(batch_size * n_dfts), INT64(onesided_dft_size)})) + .SetOutput(L"reshaped", L"reshaped_output")) + .Operators().Add(Operator(L"MatMul") + .SetInput(L"A", L"reshaped_output") + .SetInput(L"B", L"mel_weight_matrix") + .SetOutput(L"Y", L"mel_spectrogram")) + .Operators().Add(Operator(L"Reshape") + .SetInput(L"data", L"mel_spectrogram") + .SetConstant(L"shape", TensorInt64Bit::CreateFromArray({4}, mel_spectrogram_shape)) + .SetOutput(L"reshaped", L"Output.MelSpectrogram")); + auto model = builder.CreateModel(); + + LearningModelSession session(model); + LearningModelBinding binding(session); + + // Bind input + auto signal = MakeThreeTones(signal_size, sampling_rate); + binding.Bind(L"Input.TimeSignal", TensorFloat::CreateFromArray(signal_shape, signal)); + + // Bind output + auto output_image = + winrt::Windows::Media::VideoFrame( + winrt::Windows::Graphics::Imaging::BitmapPixelFormat::Bgra8, + INT32(n_mel_bins), + INT32(n_dfts)); + binding.Bind(L"Output.MelSpectrogram", output_image); + + // Evaluate + auto start = std::chrono::high_resolution_clock::now(); + auto result = session.Evaluate(binding, L""); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration evaluate_duration_in_microseconds = end - start; + printf("Evaluate Took: %f\n", evaluate_duration_in_microseconds.count()); + + // Check the output video frame object by saving output image to disk + std::wstring out_name = L"mel_spectrogram.jpg"; + + // Save the output + std::wstring modulePath = FileHelpers::GetModulePath(); + winrt::Windows::Storage::StorageFolder folder = winrt::Windows::Storage::StorageFolder::GetFolderFromPathAsync(modulePath).get(); + winrt::Windows::Storage::StorageFile file = folder.CreateFileAsync(out_name, winrt::Windows::Storage::CreationCollisionOption::ReplaceExisting).get(); + winrt::Windows::Storage::Streams::IRandomAccessStream write_stream = file.OpenAsync(winrt::Windows::Storage::FileAccessMode::ReadWrite).get(); + winrt::Windows::Graphics::Imaging::BitmapEncoder encoder = winrt::Windows::Graphics::Imaging::BitmapEncoder::CreateAsync(winrt::Windows::Graphics::Imaging::BitmapEncoder::JpegEncoderId(), write_stream).get(); + encoder.SetSoftwareBitmap(output_image.SoftwareBitmap()); + encoder.FlushAsync().get(); + + // Save the model + builder.Save(L"spectrogram.onnx"); + printf("\n"); +} +#endif + +static void ModelBuilding_StandardDeviationNormalization() { +#ifndef BUILD_INBOX + int64_t height = 256; + int64_t width = 256; + int64_t channels = 3; + std::vector input_shape = {1, height, width, channels}; + std::vector output_shape = {1, channels, height, width}; + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", L"The NHWC image", TensorKind::Float, input_shape)) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Means", TensorKind::Float, {channels})) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"StdDevs", TensorKind::Float, {channels})) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", L"The NCHW image normalized with mean and stddev.", TensorKind::Float, output_shape)) + .Operators().Add(Operator(L"Sub") + .SetInput(L"A", L"Input") + .SetInput(L"B", L"Means") + .SetOutput(L"C", L"SubOutput")) + .Operators().Add(Operator(L"Div") + .SetInput(L"A", L"SubOutput") + .SetInput(L"B", L"StdDevs") + .SetOutput(L"C", L"DivOutput")) + .Operators().Add(Operator(L"Transpose") + .SetInput(L"data", L"DivOutput") + .SetAttribute(L"perm", TensorInt64Bit::CreateFromArray({4}, {0,3,1,2})) + .SetOutput(L"transposed", L"Output")) + .Save(L"StandardDeviationNormalization.onnx"); + //.CreateModel(); +#endif +} + +static void ModelBuilding_Gemm() { +#ifndef BUILD_INBOX + std::vector shape = {3, 3}; + auto model = + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputA", TensorKind::Float, shape)) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputB", TensorKind::Float, shape)) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputC", TensorKind::Float, shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"OutputY", TensorKind::Float, shape)) + .Operators().Add(Operator(L"Gemm") + .SetInput(L"A", L"InputA") + .SetInput(L"B", L"InputB") + .SetInput(L"C", L"InputC") + .SetOutput(L"Y", L"OutputY")) + .CreateModel(); +#endif +} + +static void ModelBuilding_DynamicMatmul() { +#ifndef BUILD_INBOX + std::vector a_shape = {318, 129}; + std::vector b_shape = {129, 1024}; + + auto model = + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputA", TensorKind::Float, a_shape)) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputB", TensorKind::Float, b_shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, {a_shape[0], b_shape[1]})) + .Operators().Add(Operator(L"MatMul") + .SetInput(L"A", L"InputA") + .SetInput(L"B", L"InputB") + .SetOutput(L"Y", L"Output")) + .CreateModel(); + + LearningModelSession session(model); + LearningModelBinding binding(session); + + // Bind A + auto a_matrix = std::vector(SIZET(a_shape[0] * a_shape[1]), 1); + binding.Bind(L"InputA", TensorFloat::CreateFromArray(a_shape, a_matrix)); + + // Bind B + auto b_matrix = std::vector(SIZET(b_shape[0] * b_shape[1]), 1); + binding.Bind(L"InputB", TensorFloat::CreateFromArray(b_shape, b_matrix)); + + // Evaluate + auto start = std::chrono::high_resolution_clock::now(); + auto result = session.Evaluate(binding, L""); + auto end = std::chrono::high_resolution_clock::now(); + + // Print duration + std::chrono::duration evaluate_duration_in_microseconds = end - start; + printf("Evaluate Took: %f\n", evaluate_duration_in_microseconds.count()); +#endif +} + +static void ModelBuilding_ConstantMatmul() { +#ifndef BUILD_INBOX + std::vector a_shape = {318, 129}; + std::vector b_shape = {129, 1024}; + + auto model = + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputA", TensorKind::Float, a_shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, {a_shape[0], b_shape[1]})) + .Operators().Add(Operator(L"MatMul") + .SetInput(L"A", L"InputA") + .SetConstant(L"B", TensorFloat::CreateFromArray(b_shape, std::vector(SIZET(b_shape[0] * b_shape[1]), 1))) + .SetOutput(L"Y", L"Output")) + .CreateModel(); + + LearningModelSession session(model); + LearningModelBinding binding(session); + + // Bind input + auto a_matrix = std::vector(SIZET(a_shape[0] * a_shape[1]), 1); + binding.Bind(L"InputA", TensorFloat::CreateFromArray(a_shape, a_matrix)); + + // Evaluate + auto start = std::chrono::high_resolution_clock::now(); + auto result = session.Evaluate(binding, L""); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration evaluate_duration_in_microseconds = end - start; + printf("Evaluate Took: %f\n", evaluate_duration_in_microseconds.count()); +#endif +} + +static void ModelBuilding_DiscreteFourierTransform() { +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) + DiscreteFourierTransform(false /*onesided*/); + DiscreteFourierTransform(true /*onesided*/); +#endif +} + +static void ModelBuilding_DiscreteFourierTransformInverseIdentity() { +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) + std::vector shape = {1, 5}; + std::vector output_shape = {1, shape[1], 2}; + + auto model = + LearningModelBuilder::Create(13) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input.TimeSignal", TensorKind::Float, shape)) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output.Spectra", TensorKind::Float, output_shape)) + .Operators().Add(Operator(L"DFT", MS_EXPERIMENTAL_DOMAIN) + .SetInput(L"input", L"Input.TimeSignal") + .SetOutput(L"output", L"DFTOutput")) + .Operators().Add(Operator(L"IDFT", MS_EXPERIMENTAL_DOMAIN) + .SetInput(L"input", L"DFTOutput") + .SetOutput(L"output", L"Output.Spectra")) + .CreateModel(); + + LearningModelSession session(model); + LearningModelBinding binding(session); + + // Populate binding + binding.Bind(L"Input.TimeSignal", TensorFloat::CreateFromArray(shape, {1, 2, 3, 4, 5})); + + // Evaluate + auto result = session.Evaluate(binding, L""); + + // Check results + printf("Output.Spectra\n"); + auto y_tensor = result.Outputs().Lookup(L"Output.Spectra").as(); + auto y_ivv = y_tensor.GetAsVectorView(); + for (int i = 0; i < output_shape[0] * output_shape[1] * 2; i += 2) { + printf("(%f + %fi), ", y_ivv.GetAt(i), y_ivv.GetAt(i + 1)); + } + printf("\n"); +#endif +} + +static void ModelBuilding_HannWindow() { +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) + WindowFunction(L"HannWindow", TensorKind::Float); + WindowFunction(L"HannWindow", TensorKind::Double); +#endif +} + +static void ModelBuilding_HammingWindow() { +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) + WindowFunction(L"HammingWindow", TensorKind::Float); + WindowFunction(L"HammingWindow", TensorKind::Double); +#endif +} + +static void ModelBuilding_BlackmanWindow() { +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) + WindowFunction(L"BlackmanWindow", TensorKind::Float); + WindowFunction(L"BlackmanWindow", TensorKind::Double); +#endif +} + +static void ModelBuilding_STFT() { +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) + size_t batch_size = 1; + size_t sample_rate = 8192; + float signal_duration_in_seconds = 5.f; + size_t signal_size = static_cast(sample_rate * signal_duration_in_seconds); + size_t dft_size = 256; + size_t hop_size = 128; + + // stft + STFT(batch_size, signal_size, dft_size, hop_size, sample_rate, true); + STFT(batch_size, signal_size, dft_size, hop_size, sample_rate, false); +#endif +} + +static void ModelBuilding_MelSpectrogramOnThreeToneSignal() { +#if !defined(BUILD_INBOX) && defined(BUILD_MS_EXPERIMENTAL_OPS) + size_t batch_size = 1; + size_t sample_rate = 8192; + float signal_duration_in_seconds = 5.f; + size_t signal_size = static_cast(sample_rate * signal_duration_in_seconds); + size_t dft_size = 256; + size_t hop_size = 128; + size_t window_size = 256; + size_t n_mel_bins = 1024; + + MelSpectrogramOnThreeToneSignal(batch_size, signal_size, dft_size, window_size, hop_size, n_mel_bins, sample_rate); +#endif +} static void SetIntraOpNumThreads() { auto shape = std::vector{1, 1000}; @@ -454,7 +988,7 @@ static void SetIntraOpNumThreads() { // Check to see that bind and evaluate continue to work when setting the intra op thread count std::vector input(1000); std::iota(std::begin(input), std::end(input), 0.0f); - auto tensor_input = TensorFloat::CreateFromShapeArrayAndDataArray(shape, input); + auto tensor_input = TensorFloat::CreateFromArray(shape, input); auto binding = LearningModelBinding(session); binding.Bind(L"input", tensor_input); WINML_EXPECT_NO_THROW(session.Evaluate(binding, L"")); @@ -485,7 +1019,19 @@ const LearningModelSessionAPITestsApi& getapi() { EvaluateSessionAndCloseModel, NamedDimensionOverride, CloseSession, - SetIntraOpNumThreads + SetIntraOpNumThreads, + ModelBuilding_Gemm, + ModelBuilding_StandardDeviationNormalization, + ModelBuilding_DynamicMatmul, + ModelBuilding_ConstantMatmul, + ModelBuilding_DiscreteFourierTransform, + ModelBuilding_DiscreteFourierTransformInverseIdentity, + ModelBuilding_HannWindow, + ModelBuilding_HammingWindow, + ModelBuilding_BlackmanWindow, + ModelBuilding_STFT, + ModelBuilding_MelSpectrogramOnThreeToneSignal, + ModelBuilding_MelWeightMatrix, }; if (SkipGpuTests()) { diff --git a/winml/test/api/LearningModelSessionAPITest.h b/winml/test/api/LearningModelSessionAPITest.h index da630527b987e..1baef2e76139c 100644 --- a/winml/test/api/LearningModelSessionAPITest.h +++ b/winml/test/api/LearningModelSessionAPITest.h @@ -21,6 +21,18 @@ struct LearningModelSessionAPITestsApi { VoidTest OverrideNamedDimension; VoidTest CloseSession; VoidTest SetIntraOpNumThreads; + VoidTest ModelBuilding_Gemm; + VoidTest ModelBuilding_StandardDeviationNormalization; + VoidTest ModelBuilding_DynamicMatmul; + VoidTest ModelBuilding_ConstantMatmul; + VoidTest ModelBuilding_DiscreteFourierTransform; + VoidTest ModelBuilding_DiscreteFourierTransformInverseIdentity; + VoidTest ModelBuilding_HannWindow; + VoidTest ModelBuilding_HammingWindow; + VoidTest ModelBuilding_BlackmanWindow; + VoidTest ModelBuilding_STFT; + VoidTest ModelBuilding_MelSpectrogramOnThreeToneSignal; + VoidTest ModelBuilding_MelWeightMatrix; }; const LearningModelSessionAPITestsApi& getapi(); @@ -43,4 +55,16 @@ WINML_TEST(LearningModelSessionAPITests, AdapterIdAndDevice) WINML_TEST(LearningModelSessionAPITests, OverrideNamedDimension) WINML_TEST(LearningModelSessionAPITests, CloseSession) WINML_TEST(LearningModelSessionAPITests, SetIntraOpNumThreads) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_Gemm) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_StandardDeviationNormalization) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DynamicMatmul) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_ConstantMatmul) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransform) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransformInverseIdentity) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_HannWindow) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_HammingWindow) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_BlackmanWindow) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_STFT) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_MelSpectrogramOnThreeToneSignal) +WINML_TEST(LearningModelSessionAPITests, ModelBuilding_MelWeightMatrix) WINML_TEST_CLASS_END()