diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index fa6c731231405..8659c96b540c8 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -58,11 +58,11 @@ Do not modify directly.* |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BlackmanWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float)| diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index 05f4c10995ef2..4cc57ba4b5391 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -319,6 +319,10 @@ class CallableDispatchableHelper { public: explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {} +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4702) +#endif // Must return integer to be in a expandable context template int Invoke(Fn&& fn, Args&&... args) { @@ -328,6 +332,9 @@ class CallableDispatchableHelper { } return 0; } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif void CheckCalledOnce() const { ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); @@ -338,7 +345,7 @@ class CallableDispatchableHelper { // Other policies may set the second result argument accordingly. 
template struct UnsupportedTypeDefaultPolicy { - void operator()(int32_t dt_type, Ret& /*result*/) const { + [[noreturn]] void operator()(int32_t dt_type, Ret& /*result*/) const { ORT_THROW("Unsupported data type: ", dt_type); } }; diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index d1c280d9886f4..685937049e58f 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -31,7 +31,7 @@ namespace op_kernel_type_control { // we're using one set of types for all opsets of Cast ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0, - element_type_lists::AllIRv9); + element_type_lists::AllIRv10); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0, @@ -39,7 +39,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0, - element_type_lists::AllIRv9); + element_type_lists::AllIRv10); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0, @@ -58,8 +58,43 @@ using IsOrtFloat16Type = boost::mp11::mp_contains, #if !defined(DISABLE_FLOAT8_TYPES) template using IsOrtFloat8Type = boost::mp11::mp_contains; +#else +template +struct IsOrtFloat8Type : std::false_type {}; #endif +template +using IsOrtInt4Type = boost::mp11::mp_contains, T>; + +template +struct IsStandardIntegerType { + static constexpr bool value = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; +}; + +// Types that Int4x2 and UInt4x2 convert to and from, apart from string. 
+template +struct IsOrtInt4NumericConversionType { + static constexpr bool value = + std::is_same_v || + IsStandardIntegerType::value || + std::is_floating_point_v || + IsOrtFloat16Type::value || + IsOrtFloat8Type::value; +}; + +template +struct IsOrtInt4ConversionType { + static constexpr bool value = IsOrtInt4NumericConversionType::value || std::is_same_v; +}; + // string cast helpers // Note: when C++17 is available, use functions @@ -115,11 +150,7 @@ CastToString(const SrcType& input, std::string& output) { } template -#if !defined(DISABLE_FLOAT8_TYPES) typename std::enable_if::value || IsOrtFloat8Type::value, void>::type -#else -typename std::enable_if::value>::type -#endif CastToString(const SrcType& input, std::string& output) { CastToString(static_cast(input), output); } @@ -149,11 +180,7 @@ CastFromString(const std::string& input, DstType& output) { } template -#if !defined(DISABLE_FLOAT8_TYPES) typename std::enable_if::value || IsOrtFloat8Type::value, void>::type -#else -typename std::enable_if::value, void>::type -#endif CastFromString(const std::string& input, DstType& output) { float intermediate; CastFromString(input, intermediate); @@ -177,6 +204,134 @@ template <> struct EigenCastType { using type = Eigen::bfloat16; }; + +// Helper for converting (U)Int4x2 values to any destination type. +template ::value && IsOrtInt4ConversionType::value>> +struct FromInt4Converter { + // The input 'val' can be either an int8_t value coming from Int4x2.GetElem(pos), + // or an uint8_t value coming from UInt4x2.GetElem(pos), where pos can be 0 or 1. 
+ static DstType Convert(typename SrcType::UnpackedType val) { + if constexpr (IsOrtFloat16Type::value) { + return DstType(static_cast(val)); + } else if constexpr (IsOrtFloat8Type::value) { + return DstType(static_cast(val), true); + } else if constexpr (std::is_same_v) { + return val != 0; + } else if constexpr (std::is_same_v) { + // val has type (u)int8_t, so static_cast is required in order for std::to_string + // to interpret val as a number (65 -> "65"), instead of a char (65 -> "A"). + return std::to_string(static_cast(val)); + } else { + return static_cast(val); + } + } +}; + +// Helper for converting any source type to (U)Int4x2::UnpackedType values (int8_t and uint8_t). +template ::value && IsOrtInt4Type::value>> +struct ToInt4Converter { + static typename DstType::UnpackedType Convert(const SrcType& val); +}; + +// See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from +// fixed point to fixed point: when OOR, discard higher bits and reinterpret +// (with respect to two's complement representation for signed types). +// The following example is listed: 200 (int16) converts to -56 (int8). +// For our int4 conversion, 200 (int16) would convert to -8 (int4). +template +struct ToInt4Converter::value>> { + static int8_t Convert(const SrcType& val) { + // Example: int8_t(14) converts to int4 (-2) + // int8_t(14) is 0000_1110 + // truncate: 0000_1110 & 0000_1111 = 0000_1110 + // in 4-bit two's complement, 1110 = 1 * -8 + 1 * 4 + 1 * 2 + 0 * 1 = -2 + // sign-extend: -2 in int8_t is 1111_0000 | 0000_1110 = 1111_1110 + + // Truncate to 4 least significant bits + uint8_t truncated = static_cast(val) & 0x0F; + + // Sign-extend: if bit 3 is set, it's negative in 4-bit two's complement, + // so set the 4 most significant bits to 1. + return static_cast((truncated & 0x8) ? 
(truncated | 0xF0) : truncated); + } +}; + +// See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from +// fixed point to fixed point: when OOR, discard higher bits and reinterpret +// (with respect to two's complement representation for signed types). +template +struct ToInt4Converter::value>> { + static uint8_t Convert(const SrcType& val) { + // Truncate to 4 least significant bits + return static_cast(val) & 0x0F; + } +}; + +// bool -> (U)Int4x2 +template +struct ToInt4Converter::value>> { + static typename DstType::UnpackedType Convert(bool val) { + return static_cast(val ? 1 : 0); + } +}; + +// float -> (U)Int4x2 +// Per https://onnx.ai/onnx/operators/onnx__Cast.html#summary, casting from +// floating point to fixed point is undefined if OOR. +template +struct ToInt4Converter::value>> { + static typename DstType::UnpackedType Convert(const float& val) { + int result = static_cast(std::roundf(val)); + return ToInt4Converter::Convert(result); + } +}; + +// double -> (U)Int4x2 +template +struct ToInt4Converter::value>> { + static typename DstType::UnpackedType Convert(const double& val) { + int result = static_cast(std::round(val)); + return ToInt4Converter::Convert(result); + } +}; + +// float 8 -> (U)Int4x2 +template +struct ToInt4Converter::value && IsOrtInt4Type::value>> { + static typename DstType::UnpackedType Convert(const SrcType& val) { + float result = val.ToFloat(); + return ToInt4Converter::Convert(result); + } +}; + +// float 16 -> (U)Int4x2 +template +struct ToInt4Converter::value && IsOrtInt4Type::value>> { + static typename DstType::UnpackedType Convert(const SrcType& val) { + float f_val = static_cast(val); + return ToInt4Converter::Convert(f_val); + } +}; + +// string -> (U)Int4x2 +template +struct ToInt4Converter::value>> { + static typename DstType::UnpackedType Convert(const std::string& val) { + double result = std::stod(val); + return ToInt4Converter::Convert(result); + } +}; + // generic tensor X -> Y template struct 
TensorCaster { @@ -193,9 +348,10 @@ struct TensorCaster { } }; -// tensor X -> string +// tensor X -> string, if X != (U)Int4x2 template -struct TensorCaster { +struct TensorCaster::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const std::ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -206,9 +362,10 @@ struct TensorCaster { } }; -// tensor string -> X +// tensor string -> X, if X != (U)Int4x2 template -struct TensorCaster { +struct TensorCaster::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const std::ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -219,46 +376,105 @@ struct TensorCaster { } }; -#if !defined(DISABLE_FLOAT8_TYPES) +// tensor MLFloat16 -> float +template <> +struct TensorCaster { + void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const { + auto out_data = out.MutableData(); + auto in_data = in.Data(); + const size_t shape_size = narrow(shape.Size()); + MlasConvertHalfToFloatBufferInParallel(in_data, out_data, shape_size, ctx.GetOperatorThreadPool()); + } +}; -// tensor X -> float 8 -template -struct TensorCasterNoSat { +// tensor float -> MLFloat16 +template <> +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const std::ptrdiff_t shape_size = narrow(shape.Size()); + auto in_data = in.Data(); + auto out_data = out.MutableData(); + const size_t shape_size = narrow(shape.Size()); + MlasConvertFloatToHalfBuffer(in_data, out_data, shape_size); + } +}; + +// (U)Int4x2 -> string or numeric types +template +struct TensorCaster::value && IsOrtInt4ConversionType::value>> { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = 
in.Data(); auto* out_data = out.MutableData(); - for (std::ptrdiff_t i = 0; i < shape_size; ++i) { - out_data[i] = DstType(static_cast(in_data[i]), false); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); + out_data[i] = FromInt4Converter::Convert(val); } } }; -// tensor string -> float 8 -template -struct TensorCasterNoSat { +// string or numeric types -> (U)Int4x2 +template +struct TensorCaster::value && IsOrtInt4Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const std::ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - float float_value; - for (std::ptrdiff_t i = 0; i < shape_size; ++i) { - CastFromString(in_data[i], float_value); - out_data[i] = DstType(float_value, false); + + ptrdiff_t i = 0; + for (; i < shape_size - 1; i += 2) { + auto low_val = ToInt4Converter::Convert(in_data[i]); + auto high_val = ToInt4Converter::Convert(in_data[i + 1]); + out_data[i >> 1] = DstType(low_val, high_val); + } + + if (i < shape_size) { + auto low_val = ToInt4Converter::Convert(in_data[i]); + out_data[i >> 1] = DstType(low_val, 0); } } }; -#endif +// Int4x2 -> UInt4x2 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size() + 1) >> 1; + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); -// tensor MLFloat16 -> float + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); + + uint8_t low_unsigned = static_cast(low_nibble) & 0x0F; + uint8_t high_unsigned = static_cast(high_nibble) & 0x0F; + + out_data[i] = UInt4x2(low_unsigned, 
high_unsigned); + } + } +}; + +// UInt4x2 -> Int4x2 template <> -struct TensorCaster { - void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const { - auto out_data = out.MutableData(); - auto in_data = in.Data(); - const size_t shape_size = narrow(shape.Size()); - MlasConvertHalfToFloatBufferInParallel(in_data, out_data, shape_size, ctx.GetOperatorThreadPool()); +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size() + 1) >> 1; + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); + + int8_t low_signed = static_cast((low_nibble & 0x0F) << 4) >> 4; + int8_t high_signed = static_cast((high_nibble & 0x0F) << 4) >> 4; + + out_data[i] = Int4x2(low_signed, high_signed); + } } }; @@ -284,7 +500,8 @@ void CastMLFloat16ThroughFloatTensor( // tensor MLFloat16 -> X template -struct TensorCaster { +struct TensorCaster::value>> { void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& in, Tensor& out) const { CastMLFloat16ThroughFloatTensor(context, shape, in, out); } @@ -299,6 +516,54 @@ struct TensorCaster { }; #endif +#if !defined(DISABLE_FLOAT8_TYPES) +// TensorCasterNoSat is only called when all the below conditions are met (see Cast::Compute): +// - defined(DISABLE_FLOAT8_TYPES) == false +// - saturate_ == false +// - IsOrtFloat8Type::value == true + +// tensor X -> float 8 +template ::value>> +struct TensorCasterNoSat { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const std::ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + out_data[i] = 
DstType(static_cast(in_data[i]), false); + } + } +}; + +// tensor (U)Int4x2 -> float 8 +template +struct TensorCasterNoSat::value && IsOrtFloat8Type::value>> { + void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& src, Tensor& dst) const { + // Int4x2/UInt4x2 always fit inside any Float8 type, so we can reuse the saturate = true implementation. + TensorCaster{}.Cast(context, shape, src, dst); + } +}; + +// tensor string -> float 8 +template +struct TensorCasterNoSat::value>> { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const std::ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + float float_value; + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + CastFromString(in_data[i], float_value); + out_data[i] = DstType(float_value, false); + } + } +}; + +#endif + class Cast final : public OpKernel { public: Cast(const OpKernelInfo& info) : OpKernel(info) { diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index a74ecacc1f26e..647c947f37f0c 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -948,6 +948,16 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"slice_neg_steps", "Type parameter (Tind) bound to different types (tensor(int64) and tensor(int32) in node ()."}, {"cast_BFLOAT16_to_FLOAT", "Unexpected input data type"}, + {"cast_FLOAT16_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_FLOAT16_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_FLOAT_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_FLOAT_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which 
should include @onnx/onnx/pull/7074"}, + {"cast_INT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_INT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_INT4_to_INT8", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_UINT8", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, {"loop13_seq", "Creation of empty sequences is currently not supported in the test runner"}, {"sequence_insert_at_front", "shape mismatch, expect {4} got {3}"}, {"cast_FLOAT_to_BFLOAT16", "expect uint16 got bfloat16"}, diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 384adb5916cc1..68d4f3559504a 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#include @@ -57,7 +57,7 @@ void TestCastOp(gsl::span input, const BaseTester::DimsVariant& dimensions, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& expected_failure_string = "", - int opset = 13, + int opset = 21, Saturate saturate = Saturate::None) { OpTester test("Cast", opset); test.AddAttribute("to", utils::ToTensorProtoElementType()); @@ -207,6 +207,1036 @@ TEST(CastOpTest, ToString) { TestCastOp(gsl::make_span(int_16_input), gsl::make_span(int_string_data), shape); } +TEST(CastOpTest, Int4x2ToInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + const std::vector expected_int8_output = {-8, 7, 0, -1, 3, -5, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int8_output), shape); +} + +TEST(CastOpTest, Int4x2ToUInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint8_output = {248, 7, 0, UINT8_MAX, 3, 251, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint8_output), shape); +} + +TEST(CastOpTest, Int4x2ToInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + const std::vector expected_int16_output = {-8, 7, 0, -1, 3, -5, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int16_output), shape); +} + 
+TEST(CastOpTest, Int4x2ToUInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint16_output = {65528, 7, 0, UINT16_MAX, 3, 65531, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint16_output), shape); +} + +TEST(CastOpTest, Int4x2ToInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + const std::vector expected_int32_output = {-8, 7, 0, -1, 3, -5, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int32_output), shape); +} + +TEST(CastOpTest, Int4x2ToUInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint32_output = {4294967288, 7, 0, UINT32_MAX, 3, 4294967291, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint32_output), shape); +} + +TEST(CastOpTest, Int4x2ToInt32OddNumberOfElements) { + // GIVEN + const std::vector odd_shape{5}; + const std::vector odd_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, 0), + }; + + const std::vector expected_odd_output = {-8, 7, 0, -1, 3}; + + // WHEN, THEN + TestCastOp(gsl::make_span(odd_input), gsl::make_span(expected_odd_output), odd_shape); +} + +TEST(CastOpTest, Int4x2ToInt64) { + 
// GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + const std::vector expected_int64_output = {-8, 7, 0, -1, 3, -5, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int64_output), shape); +} + +TEST(CastOpTest, Int4x2ToUInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint64_output = {18446744073709551608ULL, 7, 0, UINT64_MAX, 3, 18446744073709551611ULL, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint64_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_uint8_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint8_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_int8_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int8_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + 
const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_uint16_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint16_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_int16_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int16_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_uint32_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint32_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_int32_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int32_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed 
values + }; + + const std::vector expected_uint64_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint64_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_int64_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int64_output), shape); +} + +TEST(CastOpTest, Int4x2ToBool) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(0, -1), // zero and non-zero + Int4x2(7, 0), // non-zero and zero + Int4x2(-8, 3), // both non-zero + Int4x2(0, 0) // both zero + }; + + const bool bool_output[] = {false, true, true, false, true, true, false, false}; + const gsl::span expected_bool_output_span(bool_output); + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), expected_bool_output_span, shape); +} + +TEST(CastOpTest, UInt4x2ToBool) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 1), // zero and non-zero + UInt4x2(15, 0), // non-zero and zero + UInt4x2(8, 7), // both non-zero + UInt4x2(0, 0) // both zero + }; + + const bool bool_output[] = {false, true, true, false, true, true, false, false}; + const gsl::span expected_bool_output_span(bool_output); + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), expected_bool_output_span, shape); +} + +TEST(CastOpTest, Int4x2ToFloat) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(1, 2), // two 4-bit int elements: lower = 1, upper = 2 + Int4x2(-3, -4), + Int4x2(5, -6), + Int4x2(-8, 7)}; + + const std::vector expected_float_output = {1.0f, 2.0f, -3.0f, -4.0f, 5.0f, -6.0f, 
-8.0f, 7.0f}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float_output), shape); +} + +TEST(CastOpTest, UInt4x2ToFloat) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 1), + UInt4x2(2, 3), + UInt4x2(7, 8), + UInt4x2(14, 15)}; + + const std::vector expected_float_output = {0.0f, 1.0f, 2.0f, 3.0f, 7.0f, 8.0f, 14.0f, 15.0f}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_float_output), shape); +} + +TEST(CastOpTest, Int4x2ToDouble) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -3), // zero and negative + Int4x2(4, -2), // positive and negative + Int4x2(1, 6) // both positive + }; + + const std::vector expected_double_output = {-8.0, 7.0, 0.0, -3.0, 4.0, -2.0, 1.0, 6.0}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_double_output), shape); +} + +TEST(CastOpTest, UInt4x2ToDouble) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_double_output = {0.0, 15.0, 1.0, 14.0, 7.0, 8.0, 3.0, 12.0}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_double_output), shape); +} + +TEST(CastOpTest, Int4x2ToMLFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + const std::vector expected_float16_output = + CastedValues( + gsl::make_span( + std::vector{-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f})); + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float16_output), shape); +} + +TEST(CastOpTest, UInt4x2ToMLFloat16) { + // GIVEN + const 
std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + const std::vector expected_float16_output = + CastedValues( + gsl::make_span( + std::vector{0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f})); + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_float16_output), shape); +} + +TEST(CastOpTest, Int4x2ToBFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + const std::vector expected_bfloat16_output = + CastedValues( + gsl::make_span( + std::vector{-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f})); + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_bfloat16_output), shape); +} + +TEST(CastOpTest, UInt4x2ToBFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + const std::vector expected_bfloat16_output = + CastedValues( + gsl::make_span( + std::vector{0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f})); + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_bfloat16_output), shape); +} + +TEST(CastOpTest, Int4x2ToString) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // mixed values + Int4x2(6, 2) // positive values + }; + + // Each Int4x2 becomes two string values + const std::vector expected_output = { + "-8", "7", // from first Int4x2 + "0", "-1", // from second Int4x2 + "3", "-5", // from third Int4x2 + "6", "2" // from fourth Int4x2 + }; + + // WHEN, THEN + TestCastOp(gsl::span(int4x2_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, UInt4x2ToString) { + // GIVEN + const std::vector shape{2, 2, 2}; + 
const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(8, 7), // mid-range values + UInt4x2(3, 12), // mixed values + UInt4x2(10, 5) // other values + }; + + // Each UInt4x2 becomes two string values + const std::vector expected_output = { + "0", "15", // from first UInt4x2 + "8", "7", // from second UInt4x2 + "3", "12", // from third UInt4x2 + "10", "5" // from fourth UInt4x2 + }; + + // WHEN, THEN + TestCastOp(gsl::span(uint4x2_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, Int4x2ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // min and max values + Int4x2(0, -1), // -1 becomes max unsigned value + Int4x2(3, -5), // positive and negative values + Int4x2(6, 2) // positive values + }; + + const std::vector expected_uint4x2_output = { + UInt4x2(8, 7), // -8 becomes 8 + UInt4x2(0, 15), // -1 becomes 15 + UInt4x2(3, 11), // -5 becomes 11 + UInt4x2(6, 2) // unchanged + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // 15 is out of int4 range + UInt4x2(1, 14), // 14 is out of int4 range + UInt4x2(7, 8), // 8 is out of int4 range + UInt4x2(3, 6) // both within range + }; + + const std::vector expected_int4x2_output = { + Int4x2(0, -1), // 15 becomes -1 + Int4x2(1, -2), // 14 becomes -2 + Int4x2(7, -8), // 8 becomes -8 + Int4x2(3, 6) // unchanged + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Int8ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int8_input = {-10, 15, 0, -1, 7, -8, -128, 127}; + + const std::vector expected_int4x2_output = { + // 10 in binary is 00001010. 
+ // Invert all bits -> 11110101, add 1 -> 11110110 + // So -10 in binary is 11110110. + // Truncate to 4 least significant bits -> 0110. + // In 4-bit two's complement, 0110 = 0 * -8 + 1 * 4 + 1 * 2 = 6. + Int4x2(6, -1), // -10 truncated to 6, 15 truncated to -1 + Int4x2(0, -1), // 0 unchanged, -1 unchanged + Int4x2(7, -8), // 7 unchanged, -8 unchanged + Int4x2(0, -1) // -128 truncated to 0, 127 truncated to -1 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(int8_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, UInt8ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint8_input = {20, 255, 0, 17, 7, 240, 15, 31}; + + // values get truncated to lower 4 bits + const std::vector expected_uint4x2_output = { + UInt4x2(4, 15), // 20 (0x14) truncated to 4, 255 (0xFF) truncated to 15 + UInt4x2(0, 1), // 0 (0x00) truncated to 0, 17 (0x11) truncated to 1 + UInt4x2(7, 0), // 7 (0x07) truncated to 7, 240 (0xF0) truncated to 0 + UInt4x2(15, 15) // 15 (0x0F) truncated to 15, 31 (0x1F) truncated to 15 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint8_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, Int16ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int16_input = {-10, 32767, 0, -32768, 7, -8, 240, 31}; + + // values get truncated to lower 4 bits and sign-extended + const std::vector expected_int4x2_output = { + Int4x2(6, -1), // -10 (0xFFF6) truncated to 6, 32767 (0x7FFF) truncated to -1 + Int4x2(0, 0), // 0 (0x0000) truncated to 0, -32768 (0x8000) truncated to 0 + Int4x2(7, -8), // 7 (0x0007) truncated to 7, -8 (0xFFF8) truncated to -8 + Int4x2(0, -1) // 240 (0x00F0) truncated to 0, 31 (0x001F) truncated to -1 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(int16_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, UInt16ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint16_input = {20, 
65535, 0, 256, 7, 240, 15, 4095}; + + // values get truncated to lower 4 bits + const std::vector expected_uint4x2_output = { + UInt4x2(4, 15), // 20 (0x0014) truncated to 4, 65535 (0xFFFF) truncated to 15 + UInt4x2(0, 0), // 0 (0x0000) truncated to 0, 256 (0x0100) truncated to 0 + UInt4x2(7, 0), // 7 (0x0007) truncated to 7, 240 (0x00F0) truncated to 0 + UInt4x2(15, 15) // 15 (0x000F) truncated to 15, 4095 (0x0FFF) truncated to 15 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint16_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, Int32ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int32_input = {-10, INT32_MAX, 0, INT32_MIN, 3, -5, 4080, 287}; + + // values get truncated to lower 4 bits and sign-extended + const std::vector expected_int4x2_output = { + Int4x2(6, -1), // -10 (0xFFFFFFF6) truncated to 6, 2147483647 (0x7FFFFFFF) truncated to -1 + Int4x2(0, 0), // 0 (0x00000000) truncated to 0, -2147483648 (0x80000000) truncated to 0 + Int4x2(3, -5), // 3 (0x00000003) truncated to 3, -5 (0xFFFFFFFB) truncated to -5 + Int4x2(0, -1) // 4080 (0x00000FF0) truncated to 0, 287 (0x0000011F) truncated to -1 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(int32_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Int32ToInt4x2OddNumberOfElements) { + // GIVEN + const std::vector odd_shape{5}; + const std::vector odd_input = {-10, INT32_MAX, 0, INT32_MIN, 4095}; + + const std::vector expected_odd_output = { + Int4x2(6, -1), // -10 truncated to 6, 2147483647 truncated to -1 + Int4x2(0, 0), // 0 truncated to 0, -2147483648 truncated to 0 + Int4x2(-1, 0) // 4095 truncated to -1, paired with 0 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(odd_input), gsl::make_span(expected_odd_output), odd_shape); +} + +TEST(CastOpTest, Int32ToInt4x2EmptyTensor) { + // GIVEN + const std::vector empty_shape{0}; + const std::vector empty_input = {}; + const std::vector empty_output = {}; + + // 
WHEN, THEN + TestCastOp(gsl::make_span(empty_input), gsl::make_span(empty_output), empty_shape); +} + +TEST(CastOpTest, UInt32ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint32_input = {20, UINT32_MAX, 0, 256, 7, 240, 15, 4095}; + + // values get truncated to lower 4 bits + const std::vector expected_uint4x2_output = { + UInt4x2(4, 15), // 20 truncated to 4, 4294967295 truncated to 15 + UInt4x2(0, 0), // 0 truncated to 0, 256 truncated to 0 + UInt4x2(7, 0), // 7 truncated to 7, 240 truncated to 0 + UInt4x2(15, 15) // 15 truncated to 15, 4095 truncated to 15 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint32_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, Int64ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int64_input = {-10, INT64_MAX, 0, INT64_MIN, 7, -8, 65520, 4111}; + + // values get truncated to lower 4 bits and sign-extended + const std::vector expected_int4x2_output = { + Int4x2(6, -1), // -10 truncated to 6, 9223372036854775807 truncated to -1 + Int4x2(0, 0), // 0 truncated to 0, -9223372036854775808 truncated to 0 + Int4x2(7, -8), // 7 truncated to 7, -8 truncated to -8 + Int4x2(0, -1) // 65520 truncated to 0, 4111 truncated to -1 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(int64_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, UInt64ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint64_input = {20, UINT64_MAX, 0, 256, 7, 240, 15, 4095}; + + // values get truncated to lower 4 bits + const std::vector expected_uint4x2_output = { + UInt4x2(4, 15), // 20 truncated to 4, 18446744073709551615 truncated to 15 + UInt4x2(0, 0), // 0 truncated to 0, 256 truncated to 0 + UInt4x2(7, 0), // 7 truncated to 7, 240 truncated to 0 + UInt4x2(15, 15) // 15 truncated to 15, 4095 truncated to 15 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint64_input), gsl::make_span(expected_uint4x2_output), 
shape); +} + +TEST(CastOpTest, FloatToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector float_input = {-10.7f, 15.3f, 0.4f, -1.6f, 7.0f, -8.0f, 240.1f, 31.9f}; + + const std::vector expected_int4x2_output = { + Int4x2(5, -1), // -10.7 rounded to -11 (0xF5), truncated to 5, sign-extended to 5; 15.3 rounded to 15 (0x0F), sign-extended to -1 + Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to -2 (0xFE), truncated to 14 (0x0E), sign-extended to -2 + Int4x2(7, -8), // 7.0 converted to 7; -8.0 converted to -8 + Int4x2(0, 0) // 240.1 rounded to 240 (0xF0), truncated to 0; 31.9 rounded to 32 (0x20), truncated to 0 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(float_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, DoubleToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector double_input = {20.7, 255.3, 0.4, 1.6, 7.8, 240.2, 15.1, 31.9}; + + const std::vector expected_uint4x2_output = { + UInt4x2(5, 15), // 20.7 rounded to 21, truncated to 5; 255.3 rounded to 255, truncated to 15 + UInt4x2(0, 2), // 0.4 rounded to 0; 1.6 rounded to 2 + UInt4x2(8, 0), // 7.8 rounded to 8; 240.2 rounded to 240, truncated to 0 + UInt4x2(15, 0) // 15.1 rounded to 15; 31.9 rounded to 32, truncated to 0 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(double_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, MLFloat16ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const MLFloat16 mlfloat16_array[8] = { + MLFloat16(static_cast(-10.7f)), + MLFloat16(static_cast(15.3f)), + MLFloat16(static_cast(0.4f)), + MLFloat16(static_cast(-1.6f)), + MLFloat16(static_cast(3.8f)), + MLFloat16(static_cast(-5.2f)), + MLFloat16(static_cast(240.1f)), + MLFloat16(static_cast(31.9f))}; + + const std::vector expected_int4x2 = { + Int4x2(5, -1), // -10.7 rounded to -11 (0xF5), truncated to 5; 15.3 rounded to 15 (0x0F), sign-extended to -1 + Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to 
-2 (0xFE), truncated to 14 (0x0E), sign-extended to -2 + Int4x2(4, -5), // 3.8 rounded to 4; -5.2 rounded to -5 (0xFB), truncated to 11 (0x0B), sign-extended to -5 + Int4x2(0, 0) // 240.1 rounded to 240 (0xF0), truncated to 0; 31.9 rounded to 32 (0x20), truncated to 0 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(mlfloat16_array, 8), + gsl::span(expected_int4x2), + shape); +} + +TEST(CastOpTest, MLFloat16ToUInt4x2) { + // GIVEN + // 8 MLFloat16 values will compress to 4 UInt4x2 values + const std::vector shape{2, 4}; // Shape that contains 8 elements + + // MLFloat16 values with edge cases and truncation scenarios + const MLFloat16 mlfloat16_array[8] = { + MLFloat16(static_cast(20.7f)), + MLFloat16(static_cast(255.3f)), + MLFloat16(static_cast(0.4f)), + MLFloat16(static_cast(1.6f)), + MLFloat16(static_cast(7.8f)), + MLFloat16(static_cast(240.2f)), + MLFloat16(static_cast(15.1f)), + MLFloat16(static_cast(31.9f))}; + + const std::vector expected_uint4x2 = { + UInt4x2(5, 15), // 20.7 rounded to 21, truncated to 5; 255.3 rounded to 255, truncated to 15 + UInt4x2(0, 2), // 0.4 rounded to 0; 1.6 rounded to 2 + UInt4x2(8, 0), // 7.8 rounded to 8; 240.2 rounded to 240, truncated to 0 + UInt4x2(15, 0) // 15.1 rounded to 15; 31.9 rounded to 32, truncated to 0 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(mlfloat16_array, 8), + gsl::span(expected_uint4x2), + shape); +} + +TEST(CastOpTest, MLFloat16ToInt4x2BoundaryValues) { + // GIVEN + // Test MLFloat16 values that need truncation to Int4x2 range + const std::vector shape{3, 2}; + const MLFloat16 mlfloat16_array[6] = { + MLFloat16(static_cast(-10)), // Truncated to lower 4 bits + MLFloat16(static_cast(9)), // Truncated to lower 4 bits + MLFloat16(static_cast(-8)), // Truncated to lower 4 bits + MLFloat16(static_cast(7)), // Truncated to lower 4 bits + MLFloat16(static_cast(-0.6f)), // Should round to -1 + MLFloat16(static_cast(1.7f)) // Should round to 2 + }; + + // Values get truncated to lower 4 bits and sign-extended 
+ const std::vector expected_int4x2 = { + Int4x2(6, -7), // -10 (0xFFFFFFF6) truncated to 6, 9 (0x00000009) truncated to -7 + Int4x2(-8, 7), // -8 (0xFFFFFFF8) truncated to -8, 7 (0x00000007) truncated to 7 + Int4x2(-1, 2) // -0.6 rounds to -1, 1.7 rounds to 2 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(mlfloat16_array, 6), + gsl::span(expected_int4x2), + shape); +} + +TEST(CastOpTest, MLFloat16ToUInt4x2BoundaryValues) { + // GIVEN + // Test MLFloat16 values that need truncation to UInt4x2 range + const std::vector shape{3, 2}; // Shape that contains 6 elements + const MLFloat16 mlfloat16_array[6] = { + MLFloat16(static_cast(-5)), // Negative, truncated to lower 4 bits + MLFloat16(static_cast(20)), // Above max, truncated to lower 4 bits + MLFloat16(static_cast(0)), // At min, should remain 0 + MLFloat16(static_cast(15)), // At max, should remain 15 + MLFloat16(static_cast(3.4f)), // Should round to 3 + MLFloat16(static_cast(5.7f)) // Should round to 6 + }; + + // Values get truncated to lower 4 bits (no sign extension for unsigned) + const std::vector expected_uint4x2 = { + UInt4x2(11, 4), // -5 (0xFFFFFFFB) truncated to 11, 20 (0x00000014) truncated to 4 + UInt4x2(0, 15), // 0 and 15 already within range + UInt4x2(3, 6) // 3.4 rounds to 3, 5.7 rounds to 6 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(mlfloat16_array, 6), + gsl::span(expected_uint4x2), + shape); +} + +TEST(CastOpTest, BFloat16ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const BFloat16 bfloat16_array[8] = { + BFloat16(static_cast(-10.7f)), + BFloat16(static_cast(15.3f)), + BFloat16(static_cast(0.4f)), + BFloat16(static_cast(-1.6f)), + BFloat16(static_cast(3.8f)), + BFloat16(static_cast(-5.2f)), + BFloat16(static_cast(240.1f)), + BFloat16(static_cast(31.9f))}; + + const std::vector expected_int4x2 = { + Int4x2(5, -1), // -10.7 rounded to -11 (0xF5), truncated to 5; 15.3 rounded to 15 (0x0F), sign-extended to -1 + Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to -2 (0xFE), 
truncated to 14 (0x0E), sign-extended to -2 + Int4x2(4, -5), // 3.8 rounded to 4; -5.2 rounded to -5 (0xFB), truncated to 11 (0x0B), sign-extended to -5 + Int4x2(0, 0) // 240.1 rounded to 240 (0xF0), truncated to 0; 31.9 rounded to 32 (0x20), truncated to 0 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(bfloat16_array, 8), + gsl::span(expected_int4x2), + shape); +} + +TEST(CastOpTest, BFloat16ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const BFloat16 bfloat16_array[8] = { + BFloat16(static_cast(20.7f)), + BFloat16(static_cast(255.3f)), + BFloat16(static_cast(0.4f)), + BFloat16(static_cast(1.6f)), + BFloat16(static_cast(7.8f)), + BFloat16(static_cast(240.2f)), + BFloat16(static_cast(15.1f)), + BFloat16(static_cast(31.9f))}; + + const std::vector expected_uint4x2 = { + UInt4x2(5, 15), // 20.7 rounded to 21, truncated to 5; 255.3 rounded to 255, truncated to 15 + UInt4x2(0, 2), // 0.4 rounded to 0; 1.6 rounded to 2 + UInt4x2(8, 0), // 7.8 rounded to 8; 240.2 rounded to 240, truncated to 0 + UInt4x2(15, 0) // 15.1 rounded to 15; 31.9 rounded to 32, truncated to 0 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(bfloat16_array, 8), + gsl::span(expected_uint4x2), + shape); +} + +TEST(CastOpTest, BFloat16ToUInt4x2BoundaryValues) { + // GIVEN + const std::vector shape{3, 2}; + const BFloat16 bfloat16_array[6] = { + BFloat16(static_cast(-5)), // Negative, truncated to lower 4 bits + BFloat16(static_cast(20)), // Above max, truncated to lower 4 bits + BFloat16(static_cast(0)), // At min, should remain 0 + BFloat16(static_cast(15)), // At max, should remain 15 + BFloat16(static_cast(3.4f)), // Should round to 3 + BFloat16(static_cast(5.7f)) // Should round to 6 + }; + + // Values get truncated to lower 4 bits (no clamping for consistency) + const std::vector expected_uint4x2 = { + UInt4x2(11, 4), // -5 (0xFFFFFFFB) truncated to 11, 20 (0x00000014) truncated to 4 + UInt4x2(0, 15), // 0 and 15 already within range + UInt4x2(3, 6) // 3.4 rounds to 3, 5.7 
rounds to 6 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(bfloat16_array, 6), + gsl::span(expected_uint4x2), + shape); +} + +TEST(CastOpTest, BoolToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const bool bool_input[] = {false, true, true, false, false, true, true, true}; + const gsl::span bool_input_span(bool_input); + + const std::vector expected_int4x2_output = { + Int4x2(0, 1), + Int4x2(1, 0), + Int4x2(0, 1), + Int4x2(1, 1)}; + + // WHEN, THEN + TestCastOp(bool_input_span, gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, BoolToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const bool bool_input[] = {false, true, true, false, false, true, true, true}; + const gsl::span bool_input_span(bool_input); + + const std::vector expected_uint4x2_output = { + UInt4x2(0, 1), + UInt4x2(1, 0), + UInt4x2(0, 1), + UInt4x2(1, 1)}; + + // WHEN, THEN + TestCastOp(bool_input_span, gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, StringToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector string_input = { + "-8", "7", // boundary values + "0", "-1", // zero and negative + "3", "-5", // mixed values + "6", "2" // positive values + }; + + const std::vector expected_output{ + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp(gsl::span(string_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, StringToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector string_input = { + "0", "15", // boundary values + "8", "7", // mid-range values + "3", "12", // mixed values + "10", "5" // other values + }; + + const std::vector expected_output{ + UInt4x2(0, 15), + UInt4x2(8, 7), + UInt4x2(3, 12), + UInt4x2(10, 5)}; + + // WHEN, THEN + TestCastOp(gsl::span(string_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, StringToUInt4x2BoundaryValues) { + // GIVEN + // Test string values that need truncation 
to UInt4x2 range + const std::vector shape{3, 2}; + const std::vector string_input = { + "-5", "20", // out of range values that get truncated + "16", "100", // out of range values that get truncated + "0", "15" // boundary values that are in range + }; + + // Each pair of strings becomes one UInt4x2 + // Values get truncated to lower 4 bits (no sign extension for unsigned) + const std::vector expected_output{ + UInt4x2(11, 4), // -5 (0xFFFFFFFB) truncated to 11, 20 (0x00000014) truncated to 4 + UInt4x2(0, 4), // 16 (0x00000010) truncated to 0, 100 (0x00000064) truncated to 4 + UInt4x2(0, 15) // 0 and 15 already in range + }; + + // WHEN, THEN + TestCastOp(gsl::span(string_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, FloatStringToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector string_input = { + "-10.7", "255.3", + "0.4", "2", + "6.8", "240.2", + "15.0", "-8"}; + + const std::vector expected_int4x2_output = { + Int4x2(5, -1), // -11 -> 5, 255 -> -1 + Int4x2(0, 2), + Int4x2(7, 0), + Int4x2(-1, -8)}; + + // WHEN, THEN + TestCastOp(gsl::span(string_input), gsl::span(expected_int4x2_output), shape); +} + #if !defined(DISABLE_FLOAT8_TYPES) template @@ -269,6 +1299,164 @@ TEST(CastOpTest, ToFloat8E5M2FNUZ) { } } +TEST(CastOpTest, Int4x2ToFloat8E4M3FN) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + std::vector expected_float8_output; + expected_float8_output.reserve(8); + const std::vector float_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : float_values) { + expected_float8_output.emplace_back(Float8E4M3FN(val, true)); + } + + // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape); + // Test with Saturate::False, which means the 
'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); +} + +TEST(CastOpTest, UInt4x2ToFloat8E4M3FN) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + std::vector expected_uint_float8_output; + expected_uint_float8_output.reserve(8); + const std::vector uint_float_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_float_values) { + expected_uint_float8_output.emplace_back(Float8E4M3FN(val, true)); + } + + // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); +} + +TEST(CastOpTest, Int4x2ToFloat8E5M2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + std::vector expected_float8e5m2_output; + expected_float8e5m2_output.reserve(8); + const std::vector float_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : float_values) { + expected_float8e5m2_output.emplace_back(Float8E5M2(val, true)); + } + + // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + 
TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); +} + +TEST(CastOpTest, UInt4x2ToFloat8E5M2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + std::vector expected_uint_float8e5m2_output; + expected_uint_float8e5m2_output.reserve(8); + const std::vector uint_float_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_float_values) { + expected_uint_float8e5m2_output.emplace_back(Float8E5M2(val, true)); + } + + // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); +} + +TEST(CastOpTest, Float8E4M3FNToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector float8_input; + const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type, + // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8. 
+ TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToInt4x2_OddShape) { + // GIVEN + const std::vector shape{1, 2, 3}; + std::vector float8_input; + const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5)}; + + // WHEN, THEN + // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type, + // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8. + TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector uint_float8_input; + const std::vector uint_input_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_input_values) { + uint_float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_uint4x2_output = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + // WHEN, THEN + // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type, + // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8. + TestCastOp(gsl::make_span(uint_float8_input), gsl::make_span(expected_uint4x2_output), shape); +} + #endif } // namespace test