diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index 4e7f3c6e60343..c71376d3164b3 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -59,50 +59,10 @@ struct MLFloat16 { explicit MLFloat16(uint16_t x) : val(x) {} explicit MLFloat16(float f); - // Taken from https://stackoverflow.com/a/60047308/12627730 - float AsFloat(uint32_t x) const { - float out = 0.0f; - std::memcpy(&out, &x, sizeof(x)); - return out; - } - - // Taken from https://stackoverflow.com/a/60047308/12627730 - uint32_t AsUint(float x) const { - uint32_t out = 0; - std::memcpy(&out, &x, sizeof(x)); - return out; - } - - float HalfToFloat(const uint16_t x) const { - uint16_t half = x; - if (endian::native == endian::big) { - // Taken from https://stackoverflow.com/a/2182184/12627730 - half = (x >> 8) | (x << 8); - } - - // Taken from https://stackoverflow.com/a/60047308/12627730 - // IEEE-754 16-bit floating-point format (without infinity): 1-5-10, exp-15, +-131008.0, +-6.1035156E-5, - // +-5.9604645E-8, 3.311 digits - const uint32_t e = (half & 0x7C00) >> 10; // exponent - const uint32_t m = (half & 0x03FF) << 13; // mantissa - // evil log2 bit hack to count leading zeros in denormalized format - const uint32_t v = AsUint(static_cast(m)) >> 23; - uint32_t full = (half & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) | - ((e == 0) & (m != 0)) * ((v - 37) << 23 | ((m << (150 - v)) & 0x007FE000)); // sign : normalized : denormalized - - if (endian::native == endian::big) { - // Taken from https://stackoverflow.com/a/2182184/12627730 - full = ((full >> 24) & 0xff) | // move byte 3 to byte 0 - ((full << 8) & 0xff0000) | // move byte 1 to byte 2 - ((full >> 8) & 0xff00) | // move byte 2 to byte 1 - ((full << 24) & 0xff000000); // byte 0 to byte 3 - } - - return AsFloat(full); - } + float ToFloat() const; operator float() const { - return HalfToFloat(val); + return ToFloat(); } 
}; diff --git a/include/onnxruntime/core/platform/threadpool.h b/include/onnxruntime/core/platform/threadpool.h index 126dd133fa7ae..899a74f9c4d07 100644 --- a/include/onnxruntime/core/platform/threadpool.h +++ b/include/onnxruntime/core/platform/threadpool.h @@ -281,7 +281,7 @@ class ThreadPool { /** * Tries to call the given function in parallel, with calls split into (num_batches) batches. *\param num_batches If it is zero, it will be replaced to the value of DegreeOfParallelism(). - *\param fn A std::function or STL style functor with signature of "void f(int32_t);" + *\param fn A std::function or STL style functor with signature of "void f(std::ptrdiff_t);" * Pitfall: Caller should cap `num_batches` to a reasonable value based on the cost of `fn` and the value of `total`. *For example, if fn is as simple as: int sum=0; fn = [&](int i){sum +=i;} and `total` is 100, then num_batches should *be just 1. diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index ce5112f1daccb..2bc32d3cdf413 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -25,6 +25,10 @@ namespace onnxruntime { MLFloat16::MLFloat16(float f) : val{math::floatToHalf(f)} {} +float MLFloat16::ToFloat() const { + return math::halfToFloat(val); +} + // Return the MLDataType used for a generic Tensor template <> MLDataType DataTypeImpl::GetType() { diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 1444fb9653ce8..e99a4399ea152 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -49,7 +49,7 @@ static bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& m // value = static_cast(*i.data()); // break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(i.data()->val); + value = math::halfToFloat(i.data()->val); break; 
default: ORT_THROW("Unexpected data type for Clip input of ", initializer->data_type()); diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 191b777a40b79..eb016febc33fc 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -2,8 +2,8 @@ // Licensed under the MIT License. #include -#include -#include +#include +#include #include "boost/mp11.hpp" @@ -18,15 +18,13 @@ #include "core/providers/op_kernel_type_control.h" #include "core/util/math_cpuonly.h" +#include "Eigen/src/Core/arch/Default/BFloat16.h" #include "Eigen/src/Core/arch/Default/Half.h" #if defined(_M_AMD64) #include "core/mlas/inc/mlas.h" #endif -using namespace ONNX_NAMESPACE; -using namespace boost::mp11; - namespace onnxruntime { namespace op_kernel_type_control { @@ -56,20 +54,15 @@ using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecu using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0); -using IndirectCastTypes = TypeList; - -template -using IsDirectCastType = mp_not>; - -template -using AreAllDirectCastTypes = mp_all...>; - // string cast helpers +// Note: when C++17 is available, use functions -// handle floating point input separately +// handle floating point output separately template typename std::enable_if::value, void>::type CastToString(const SrcType& input, std::string& output) { + static_assert(sizeof(SrcType) <= sizeof(double), + "largest supported floating point type is double"); if (std::isnan(input)) { output = "NaN"; } else if (std::isinf(input)) { @@ -79,19 +72,49 @@ CastToString(const SrcType& input, std::string& output) { output = "INF"; } } else { - // setprecision to 8 to match numpy default behavior - std::ostringstream convert; - convert << std::setprecision(8) << input; - output = convert.str(); + // set precision to 8 to match numpy default behavior 
+ constexpr const char* format = "%.8g"; + const double value = static_cast(input); + + char static_buffer[256]; + std::unique_ptr dynamic_buffer{}; + + gsl::span buffer_span = gsl::make_span(static_buffer); + + auto snprintf_result = std::snprintf(buffer_span.data(), buffer_span.size(), format, value); + ORT_ENFORCE(snprintf_result > 0, "snprintf() failed with return value: ", snprintf_result); + + // include trailing '\0' + const size_t required_buffer_size = gsl::narrow_cast(snprintf_result) + 1; + + if (required_buffer_size > buffer_span.size()) { + // didn't get it all, allocate a bigger buffer and retry + dynamic_buffer = onnxruntime::make_unique(required_buffer_size); + buffer_span = gsl::make_span(dynamic_buffer.get(), required_buffer_size); + snprintf_result = std::snprintf(buffer_span.data(), buffer_span.size(), format, value); + ORT_ENFORCE( + snprintf_result > 0 && + gsl::narrow_cast(snprintf_result) == buffer_span.size() - 1, + "Failed to write value with snprintf()."); + } + + output.assign(buffer_span.data(), required_buffer_size - 1); } } template typename std::enable_if::value, void>::type CastToString(const SrcType& input, std::string& output) { - std::ostringstream convert; - convert << input; - output = convert.str(); + output = std::to_string(input); +} + +// overloads for MLFloat16 and BFloat16 +void CastToString(const MLFloat16& input, std::string& output) { + CastToString(static_cast(input), output); +} + +void CastToString(const BFloat16& input, std::string& output) { + CastToString(static_cast(input), output); } template @@ -118,115 +141,121 @@ CastFromString(const std::string& input, DstType& output) { output = gsl::narrow_cast(std::stoll(input)); } -// generic scalar X -> Y -template -struct ScalarDirectCaster { - void Cast(const SrcType& in, DstType& out) const { - out = static_cast(in); - } -}; - -// scalar X -> string -template -struct ScalarDirectCaster { - void Cast(const SrcType& in, std::string& out) const { - CastToString(in, 
out); - } -}; +// overloads for MLFloat16 and BFloat16 +void CastFromString(const std::string& input, MLFloat16& output) { + float intermediate; + CastFromString(input, intermediate); + output = static_cast(intermediate); +} -// scalar string -> X -template -struct ScalarDirectCaster { - void Cast(const std::string& in, DstType& out) const { - CastFromString(in, out); - } -}; +void CastFromString(const std::string& input, BFloat16& output) { + float intermediate; + CastFromString(input, intermediate); + output = static_cast(intermediate); +} -// helper for indirect cast types -template -struct ScalarIndirectCaster { - void Cast(const SrcType& in, DstType& out) const { - IntermediateType intermediate; - ScalarDirectCaster{}.Cast(in, intermediate); - ScalarDirectCaster{}.Cast(intermediate, out); - } +// type that is usable with Eigen cast +template +struct EigenCastType { + using type = T; }; -template -struct ScalarCaster; +// ORT float16 types don't support casting, so map them to Eigen ones -template -struct ScalarCaster< - SrcType, DstType, - typename std::enable_if::value>::type> { - void Cast(const SrcType& in, DstType& out) const { - ScalarDirectCaster{}.Cast(in, out); - } +template <> +struct EigenCastType { + using type = Eigen::half; }; -template -struct ScalarCaster< - SrcType, DstType, - typename std::enable_if::value>::type> { - void Cast(const SrcType& in, DstType& out) const { - ScalarIndirectCaster{}.Cast(in, out); - } +template <> +struct EigenCastType { + using type = Eigen::bfloat16; }; // generic tensor X -> Y -template +template struct TensorCaster { - void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const { + void Cast(const OpKernelContext&, const Tensor& in, Tensor& out, const TensorShape& shape) const { + using SrcEigenCastType = typename EigenCastType::type; + using DstEigenCastType = typename EigenCastType::type; + const std::ptrdiff_t shape_size = gsl::narrow(shape.Size()); - const auto in_vector = 
ConstEigenVectorMap(in.Data(), shape_size); - auto out_vector = EigenVectorMap(out.MutableData(), shape_size); - out_vector = in_vector.unaryExpr([](const SrcType& in_scalar) { - DstType out_scalar; - ScalarCaster{}.Cast(in_scalar, out_scalar); - return out_scalar; - }); + const auto in_vector = + ConstEigenVectorMap(reinterpret_cast(in.Data()), shape_size); + auto out_vector = + EigenVectorMap(reinterpret_cast(out.MutableData()), shape_size); + out_vector = in_vector.template cast(); } }; -template -void CastStringTensor(const Tensor& in, Tensor& out, const TensorShape& shape) { - static_assert(std::is_same::value || std::is_same::value, - "Either SrcType or DstType must be std::string."); - const std::ptrdiff_t shape_size = gsl::narrow(shape.Size()); - const auto in_data = in.DataAsSpan(); - const auto out_data = out.MutableDataAsSpan(); - for (std::ptrdiff_t i = 0; i < shape_size; ++i) { - ScalarCaster{}.Cast(in_data[i], out_data[i]); - } -} - // tensor X -> string template struct TensorCaster { - void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const { - CastStringTensor(in, out, shape); + void Cast(const OpKernelContext&, const Tensor& in, Tensor& out, const TensorShape& shape) const { + const std::ptrdiff_t shape_size = gsl::narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + CastToString(in_data[i], out_data[i]); + } } }; // tensor string -> X template struct TensorCaster { - void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const { - CastStringTensor(in, out, shape); + void Cast(const OpKernelContext&, const Tensor& in, Tensor& out, const TensorShape& shape) const { + const std::ptrdiff_t shape_size = gsl::narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + CastFromString(in_data[i], out_data[i]); + } } }; #if defined(_M_AMD64) +// 
specializations to use optimized and Windows x64-specific +// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion + +template +void CastMLFloat16ThroughFloat( + const OpKernelContext& context, const Tensor& in, Tensor& out, const TensorShape& shape) { + // use optimized MLFloat16 -> float, then float -> DstType + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(context.GetTempSpaceAllocator(&allocator)); + auto intermediate_buffer = IAllocator::MakeUniquePtr(allocator, gsl::narrow(shape.Size())); + Tensor intermediate_tensor{DataTypeImpl::GetType(), shape, intermediate_buffer.get(), allocator->Info()}; + TensorCaster{}.Cast(context, in, intermediate_tensor, shape); + TensorCaster{}.Cast(context, intermediate_tensor, out, shape); +} + +// tensor MLFloat16 -> X +template +struct TensorCaster { + void Cast(const OpKernelContext& context, const Tensor& in, Tensor& out, const TensorShape& shape) const { + CastMLFloat16ThroughFloat(context, in, out, shape); + } +}; + // tensor MLFloat16 -> float template <> struct TensorCaster { - void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const { + void Cast(const OpKernelContext&, const Tensor& in, Tensor& out, const TensorShape& shape) const { auto out_data = out.MutableData(); auto in_data = in.Data(); const size_t shape_size = gsl::narrow(shape.Size()); MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size); } }; + +// tensor MLFloat16 -> string +template <> +struct TensorCaster { + void Cast(const OpKernelContext& context, const Tensor& in, Tensor& out, const TensorShape& shape) const { + CastMLFloat16ThroughFloat(context, in, out, shape); + } +}; #endif class Cast final : public OpKernel { @@ -246,17 +275,18 @@ class Cast final : public OpKernel { template struct Dispatcher { - void operator()(const Tensor& src, Tensor& dst, const TensorShape& shape) { - TensorCaster{}.Cast(src, dst, shape); + void operator()(const OpKernelContext& context, const Tensor& src, Tensor& dst, const 
TensorShape& shape) { + TensorCaster{}.Cast(context, src, dst, shape); } }; template struct SrcDispatcher { - void operator()(int32_t to, const Tensor& src, Tensor& dst, const TensorShape& shape) { - using DstTypes = mp_remove_if_q>; + void operator()( + int32_t to, const OpKernelContext& context, const Tensor& src, Tensor& dst, const TensorShape& shape) { + using DstTypes = boost::mp11::mp_remove_if_q>; utils::MLTypeCallDispatcherFromTypeList dispatcher{to}; - dispatcher.template InvokeWithLeadingTemplateArgs>(src, dst, shape); + dispatcher.template InvokeWithLeadingTemplateArgs>(context, src, dst, shape); } }; @@ -278,7 +308,7 @@ Status Cast::Compute(OpKernelContext* context) const { } utils::MLTypeCallDispatcherFromTypeList dispatcher{from}; - dispatcher.Invoke(to_, *X, *Y, shape); + dispatcher.Invoke(to_, *context, *X, *Y, shape); return Status::OK(); } diff --git a/onnxruntime/core/providers/op_kernel_type_control.h b/onnxruntime/core/providers/op_kernel_type_control.h index 72c78165d618e..c61c0381d41fb 100644 --- a/onnxruntime/core/providers/op_kernel_type_control.h +++ b/onnxruntime/core/providers/op_kernel_type_control.h @@ -273,20 +273,25 @@ struct EnabledTypes { * * In MyProvider provider's implementation of MyOp kernel: * + * namespace onnxruntime { + * namespace op_kernel_type_control { * // specify supported types, i.e., the full set of types that can be enabled * ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( * MyProvider, DomainContainingMyOp, MyOp, Input, 0, * int, float, double); + * } // namespace op_kernel_type_control + * } // namespace onnxruntime + * + * // ... * * // get enabled types * using MyOpFirstInputEnabledTypes = - * ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(MyProvider, DomainContainingMyOp, MyOp, Input, 0) + * ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(MyProvider, DomainContainingMyOp, MyOp, Input, 0); * - * ... + * // ... 
* - * // in the implementation, we can dispatch to the enabled types - * utils::MLTypeCallDispatcherFromTypeList dispatcher{firstInputRuntimeType}; - * ... + * // use MLTypeCallDispatcher to dispatch to implementations for enabled types + * using Dispatcher = onnxruntime::utils::MLTypeCallDispatcherFromTypeList; */ // all allowed type specifications should be contained in the following file diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc new file mode 100644 index 0000000000000..609f62cddfa80 --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "boost/mp11.hpp" + +#include "gsl/gsl" + +#include "gtest/gtest.h" + +#include "core/framework/data_types_internal.h" + +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +template +int GetMinRequiredCudaComputeCapability() { + return 0; +} + +template <> +int GetMinRequiredCudaComputeCapability() { + return 530; +} + +template <> +int GetMinRequiredCudaComputeCapability() { + return 800; +} + +template +void TestCastOp(gsl::span input, + gsl::span output, + const std::vector& dimensions, + OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, + const std::string& expected_failure_string = "") { + OpTester test("Cast", 13); + test.AddAttribute("to", utils::ToTensorProtoElementType()); + test.AddInput("input", dimensions, input.data(), input.size()); + test.AddOutput("output", dimensions, output.data(), output.size()); + + std::unordered_set excluded_provider_types{kTensorrtExecutionProvider}; + const auto min_required_cuda_compute_capability = + std::max(GetMinRequiredCudaComputeCapability(), GetMinRequiredCudaComputeCapability()); + if 
(!HasCudaEnvironment(min_required_cuda_compute_capability)) { + excluded_provider_types.insert(kCudaExecutionProvider); + } + + test.Run(expect_result, expected_failure_string, excluded_provider_types); +} + +template +using RequiresCastThroughFloat = + boost::mp11::mp_any< + std::is_same, + std::is_same>; + +template +using AnyRequireCastThroughFloat = boost::mp11::mp_any...>; + +template +typename std::enable_if::value>::type +CastSpan(gsl::span src, gsl::span dst) { + std::transform( + src.begin(), src.end(), dst.begin(), + [](SrcType s) { + return static_cast(static_cast(s)); + }); +} + +template +typename std::enable_if::value>::type +CastSpan(gsl::span src, gsl::span dst) { + std::transform( + src.begin(), src.end(), dst.begin(), + [](SrcType s) { + return static_cast(s); + }); +} + +template +std::vector CastedValues(gsl::span src) { + std::vector result(src.size()); + CastSpan(src, gsl::make_span(result)); + return result; +} + +struct CastNonStringTester { + template + void operator()(const std::pair&) { + SCOPED_TRACE( + onnxruntime::MakeString( + "Cast from type ", utils::ToTensorProtoElementType(), + " to type ", utils::ToTensorProtoElementType())); + + const std::vector input_int_values{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + const TensorShape shape{2, 3, 2, 2}; + const size_t size = gsl::narrow(shape.Size()); + ASSERT_EQ(input_int_values.size(), size); + + auto input_buffer = onnxruntime::make_unique(size); + auto input_span = gsl::make_span(input_buffer.get(), size); + CastSpan(gsl::make_span(input_int_values), input_span); + + auto output_buffer = onnxruntime::make_unique(size); + auto output_span = gsl::make_span(output_buffer.get(), size); + CastSpan(input_span, output_span); + + TestCastOp(input_span, output_span, shape.GetDims()); + } +}; + +using CastNonStringTypes = + boost::mp11::mp_list< + bool, + float, double, + uint8_t, uint16_t, uint32_t, uint64_t, + int8_t, int16_t, int32_t, int64_t, + 
MLFloat16, BFloat16>; + +TEST(CastOpTest, NonStringTypes) { + boost::mp11::mp_for_each>( + CastNonStringTester{}); +} + +TEST(CastOpTest, FromString) { + const std::vector shape{2, 2, 2}; + const std::vector string_data = {"-inf", "+INF", "0.9767611", "0.28280696", + "-0.12019656", "5.0", "NaN", "nan"}; + const std::vector float_output = {-(std::numeric_limits::infinity()), std::numeric_limits::infinity(), + 0.9767611f, 0.28280696f, + -0.12019656f, 5.0f, NAN, NAN}; + TestCastOp(gsl::make_span(string_data), gsl::make_span(float_output), shape); + + const std::vector float16_string_data = {"-inf", "+INF", "0.5", "0.25", + "0.0", "-1.0", "-1.5", "NaN"}; + const std::vector float16_output = + CastedValues( + gsl::make_span( + std::vector{ + -std::numeric_limits::infinity(), std::numeric_limits::infinity(), 0.5f, 0.25f, + 0.0f, -1.0f, -1.5f, NAN})); + TestCastOp(gsl::make_span(float16_string_data), gsl::make_span(float16_output), shape); + + const std::vector int_16_string_data = {"0", "1", "2", "3", "4", "5", "-32768", "32767"}; + const std::vector int_16_output = {0, 1, 2, 3, 4, 5, SHRT_MIN, SHRT_MAX}; + TestCastOp(gsl::make_span(int_16_string_data), gsl::make_span(int_16_output), shape); + + const std::vector int_64_string_data = {"0", "1", "2", "3", "4", "5", "-9223372036854775808", "9223372036854775807"}; + const std::vector int_64_output = {0, 1, 2, 3, 4, 5, LLONG_MIN, LLONG_MAX}; + TestCastOp(gsl::make_span(int_64_string_data), gsl::make_span(int_64_output), shape); +} + +TEST(CastOpTest, ToString) { + const std::vector shape{2, 2, 2}; + const std::vector float_input = {NAN, -1.f, 0.0391877927f, 0.296140194f, -0.120196559f, 5.0f, + -std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; + + // float output precision is 8, so the expected output differs slightly from the input due to that + const std::vector string_output = {"NaN", "-1", "0.039187793", "0.29614019", + "-0.12019656", "5", "-INF", "INF"}; + TestCastOp(gsl::make_span(float_input), 
gsl::make_span(string_output), shape); + + const std::vector float16_input = + CastedValues( + gsl::make_span( + std::vector{ + -std::numeric_limits::infinity(), std::numeric_limits::infinity(), 0.5f, 0.25f, + 0.0f, -1.0f, -1.5f, NAN})); + const std::vector float16_string_output = {"-INF", "INF", "0.5", "0.25", + "0", "-1", "-1.5", "NaN"}; + TestCastOp(gsl::make_span(float16_input), gsl::make_span(float16_string_output), shape); + + const std::vector int_string_data = {"0", "1", "2", "3", "4", "5", "6", "7"}; + const std::vector int_16_input = {0, 1, 2, 3, 4, 5, 6, 7}; + TestCastOp(gsl::make_span(int_16_input), gsl::make_span(int_string_data), shape); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc index 0f6a75a3d675f..279d82654b192 100644 --- a/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc @@ -84,245 +84,6 @@ TEST(TensorOpTest, ShapeTest3D) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: volume of dimensions is not consistent with weights size } -template -void TestCastOp(const std::initializer_list& input, - const std::initializer_list& output, - const std::vector& dimensions, - int64_t toType, - ExpectResult expect_result = ExpectResult::kExpectSuccess, - const std::string& expected_failure_string = "") { - OpTester test("Cast", 9); - test.AddAttribute("to", toType); - test.AddInput("input", dimensions, input); - test.AddOutput("output", dimensions, output); - test.Run(expect_result, expected_failure_string, {kTensorrtExecutionProvider}); -} - -template -void TestCastFromSrc() { - std::initializer_list input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - const std::vector shape{3, 2, 2}; - - auto float_output = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; - TestCastOp(input_data, 
float_output, shape, TensorProto::FLOAT); - - auto double_output = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; - TestCastOp(input_data, double_output, shape, TensorProto::DOUBLE); - - auto bool_output = {false, true, true, true, true, true, true, true, true, true, true, true}; - TestCastOp(input_data, bool_output, shape, TensorProto::BOOL); - - const std::initializer_list uint8_t_output{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input_data, uint8_t_output, shape, TensorProto::UINT8); - - const std::initializer_list uint16_t_output{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input_data, uint16_t_output, shape, TensorProto::UINT16); - - const std::initializer_list uint32_t_output{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input_data, uint32_t_output, shape, TensorProto::UINT32); - - const std::initializer_list uint64_t_output{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input_data, uint64_t_output, shape, TensorProto::UINT64); - - const std::initializer_list int16_t_output{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input_data, int16_t_output, shape, TensorProto::INT16); - - const std::initializer_list int32_t_output{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input_data, int32_t_output, shape, TensorProto::INT32); - - const std::initializer_list int64_t_output{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input_data, int64_t_output, shape, TensorProto::INT64); -}; - -TEST(TensorOpTest, Cast) { - TestCastFromSrc(); - TestCastFromSrc(); - TestCastFromSrc(); - TestCastFromSrc(); - TestCastFromSrc(); - TestCastFromSrc(); - TestCastFromSrc(); - TestCastFromSrc(); - TestCastFromSrc(); - TestCastFromSrc(); -} - -TEST(TensorOpTest, CastFromBool) { - auto bool_data = {false, true, true, true, true, true, true, true, true, true, false, true}; - const std::vector shape{3, 2, 2}; - - const std::initializer_list float_output = {0.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f}; - 
TestCastOp(bool_data, float_output, shape, TensorProto::FLOAT); - - const std::initializer_list double_output = {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0}; - TestCastOp(bool_data, double_output, shape, TensorProto::DOUBLE); - - auto bool_output = {false, true, true, true, true, true, true, true, true, true, false, true}; - TestCastOp(bool_data, bool_output, shape, TensorProto::BOOL); - - const std::initializer_list uint8_t_output{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}; - TestCastOp(bool_data, uint8_t_output, shape, TensorProto::UINT8); - - const std::initializer_list uint16_t_output{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}; - TestCastOp(bool_data, uint16_t_output, shape, TensorProto::UINT16); - - const std::initializer_list uint32_t_output{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}; - TestCastOp(bool_data, uint32_t_output, shape, TensorProto::UINT32); - - const std::initializer_list uint64_t_output{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}; - TestCastOp(bool_data, uint64_t_output, shape, TensorProto::UINT64); - - const std::initializer_list int16_t_output{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}; - TestCastOp(bool_data, int16_t_output, shape, TensorProto::INT16); - - const std::initializer_list int32_t_output{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}; - TestCastOp(bool_data, int32_t_output, shape, TensorProto::INT32); - - const std::initializer_list int64_t_output{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}; - TestCastOp(bool_data, int64_t_output, shape, TensorProto::INT64); - - const std::initializer_list float16_output{ - MLFloat16(math::floatToHalf(0.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(0.0f)), - MLFloat16(math::floatToHalf(1.0f))}; - 
TestCastOp(bool_data, float16_output, shape, TensorProto::FLOAT16); -} - -TEST(TensorOpTest, CastToFloat16) { - const std::vector shape{3, 2, 2}; - std::initializer_list float_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; - const std::initializer_list float16_output{ - MLFloat16(math::floatToHalf(0.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(2.0f)), - MLFloat16(math::floatToHalf(3.0f)), - MLFloat16(math::floatToHalf(4.0f)), - MLFloat16(math::floatToHalf(5.0f)), - MLFloat16(math::floatToHalf(6.0f)), - MLFloat16(math::floatToHalf(7.0f)), - MLFloat16(math::floatToHalf(8.0f)), - MLFloat16(math::floatToHalf(9.0f)), - MLFloat16(math::floatToHalf(10.0f)), - MLFloat16(math::floatToHalf(11.0f))}; - - TestCastOp(float_data, float16_output, shape, TensorProto::FLOAT16); - - std::initializer_list uint8_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(uint8_t_data, float16_output, shape, TensorProto::FLOAT16); - - std::initializer_list uint16_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(uint16_t_data, float16_output, shape, TensorProto::FLOAT16); - - std::initializer_list uint32_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(uint32_t_data, float16_output, shape, TensorProto::FLOAT16); - - std::initializer_list uint64_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(uint64_t_data, float16_output, shape, TensorProto::FLOAT16); - - std::initializer_list int8_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(int8_t_data, float16_output, shape, TensorProto::FLOAT16); - - std::initializer_list int16_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(int16_t_data, float16_output, shape, TensorProto::FLOAT16); - - std::initializer_list int32_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(int32_t_data, float16_output, shape, TensorProto::FLOAT16); - - std::initializer_list int64_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; 
- TestCastOp(int64_t_data, float16_output, shape, TensorProto::FLOAT16); -} - -TEST(TensorOpTest, CastFromFloat16) { - const std::vector shape{3, 2, 2}; - const std::initializer_list float_output = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; - const std::initializer_list input = { - MLFloat16(math::floatToHalf(0.0f)), - MLFloat16(math::floatToHalf(1.0f)), - MLFloat16(math::floatToHalf(2.0f)), - MLFloat16(math::floatToHalf(3.0f)), - MLFloat16(math::floatToHalf(4.0f)), - MLFloat16(math::floatToHalf(5.0f)), - MLFloat16(math::floatToHalf(6.0f)), - MLFloat16(math::floatToHalf(7.0f)), - MLFloat16(math::floatToHalf(8.0f)), - MLFloat16(math::floatToHalf(9.0f)), - MLFloat16(math::floatToHalf(10.0f)), - MLFloat16(math::floatToHalf(11.0f))}; - - TestCastOp(input, float_output, shape, TensorProto::FLOAT); - - auto bool_data = {false, true, true, true, true, true, true, true, true, true, true, true}; - TestCastOp(input, bool_data, shape, TensorProto::BOOL); - - std::initializer_list uint8_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input, uint8_t_data, shape, TensorProto::UINT8); - - std::initializer_list uint16_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input, uint16_t_data, shape, TensorProto::UINT16); - - std::initializer_list uint32_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input, uint32_t_data, shape, TensorProto::UINT32); - - std::initializer_list uint64_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input, uint64_t_data, shape, TensorProto::UINT64); - - std::initializer_list int8_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input, int8_t_data, shape, TensorProto::INT8); - - std::initializer_list int16_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input, int16_t_data, shape, TensorProto::INT16); - - std::initializer_list int32_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input, int32_t_data, shape, TensorProto::INT32); - 
- std::initializer_list int64_t_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - TestCastOp(input, int64_t_data, shape, TensorProto::INT64); -} - -TEST(TensorOpTest, CastFromString) { - const std::vector shape{2, 2, 2}; - std::initializer_list string_data = {"-inf", "+INF", "0.9767611", "0.28280696", - "-0.12019656", "5.0", "NaN", "nan"}; - const std::initializer_list float_output = {-(std::numeric_limits::infinity()), std::numeric_limits::infinity(), - 0.9767611f, 0.28280696f, - -0.12019656f, 5.0f, NAN, NAN}; - TestCastOp(string_data, float_output, shape, TensorProto::FLOAT); - - std::initializer_list int_16_string_data = {"0", "1", "2", "3", "4", "5", "-32768", "32767"}; - const std::initializer_list int_16_output = {0, 1, 2, 3, 4, 5, SHRT_MIN, SHRT_MAX}; - TestCastOp(int_16_string_data, int_16_output, shape, TensorProto::INT16); - - std::initializer_list int_64_string_data = {"0", "1", "2", "3", "4", "5", "-9223372036854775808", "9223372036854775807"}; - const std::initializer_list int_64_output = {0, 1, 2, 3, 4, 5, LLONG_MIN, LLONG_MAX}; - TestCastOp(int_64_string_data, int_64_output, shape, TensorProto::INT64); -} - -TEST(TensorOpTest, CastToString) { - const std::vector shape{2, 2, 2}; - const std::initializer_list float_input = {NAN, -1.f, 0.0391877927f, 0.296140194f, -0.120196559f, 5.0f, - -std::numeric_limits::infinity(), - std::numeric_limits::infinity()}; - - // float output precision is 8, so the expected output differs slightly from the input due to that - std::initializer_list string_output = {"NaN", "-1", "0.039187793", "0.29614019", - "-0.12019656", "5", "-INF", "INF"}; - TestCastOp(float_input, string_output, shape, TensorProto::STRING); - - std::initializer_list int_string_data = {"0", "1", "2", "3", "4", "5", "6", "7"}; - const std::initializer_list int_16_input = {0, 1, 2, 3, 4, 5, 6, 7}; - TestCastOp(int_16_input, int_string_data, shape, TensorProto::STRING); -} - void MeanVarianceNormalizationFunctionDefaultPerChannel() { const int64_t N = 2, 
C = 2, H = 2, W = 3; diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 506661a83b9d8..38d3eb4163534 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -251,9 +251,11 @@ void Check(const OpTester::Data& expected_data, threshold = 0.005f; #endif for (int i = 0; i < size; ++i) { - if (std::isinf(f_expected[i])) // Test infinity for equality - EXPECT_EQ(f_expected[i], f_output[i]) << "i:" << i; - else { + if (std::isnan(f_expected[i])) { // NaN never compares equal, so verify the output is NaN via isnan() + EXPECT_TRUE(std::isnan(f_output[i])) << "Expected NaN. i:" << i << ", provider_type: " << provider_type; + } else if (std::isinf(f_expected[i])) { // Test infinity for equality + EXPECT_EQ(f_expected[i], f_output[i]) << "Expected infinity. i:" << i << ", provider_type: " << provider_type; + } else { // the default for existing tests EXPECT_NEAR(f_expected[i], f_output[i], threshold) << "i:" << i << ", provider_type: " << provider_type; @@ -284,9 +286,11 @@ void Check(const OpTester::Data& expected_data, /// XXX: May need to adjust threshold as BFloat is coarse float threshold = 0.001f; for (int i = 0; i < size; ++i) { - if (std::isinf(f_expected[i])) // Test infinity for equality - EXPECT_EQ(f_expected[i], f_output[i]); - else { + if (std::isnan(f_expected[i])) { // NaN never compares equal, so verify the output is NaN via isnan() + EXPECT_TRUE(std::isnan(f_output[i])) << "Expected NaN. i:" << i << ", provider_type: " << provider_type; + } else if (std::isinf(f_expected[i])) { // Test infinity for equality + EXPECT_EQ(f_expected[i], f_output[i]) << "Expected infinity. i:" << i << ", provider_type: " << provider_type; + } else { // the default for existing tests const float max_value = fmax(fabs(f_expected[i]), fabs(f_output[i])); if (max_value != 0) { // max_value = 0 means output and expected are 0s.