diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index fa6c731231405..8659c96b540c8 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -58,11 +58,11 @@ Do not modify directly.*
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BlackmanWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|Ceil|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
|Celu|*in* X:**T**<br> *out* Y:**T**|12+|**T** = tensor(float)|
diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h
index 05f4c10995ef2..4cc57ba4b5391 100644
--- a/include/onnxruntime/core/framework/data_types_internal.h
+++ b/include/onnxruntime/core/framework/data_types_internal.h
@@ -319,6 +319,10 @@ class CallableDispatchableHelper {
public:
explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {}
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4702)
+#endif
// Must return integer to be in a expandable context
template <class T, class Fn, class... Args>
int Invoke(Fn&& fn, Args&&... args) {
@@ -328,6 +332,9 @@ class CallableDispatchableHelper {
}
return 0;
}
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
void CheckCalledOnce() const {
ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
@@ -338,7 +345,7 @@ class CallableDispatchableHelper {
// Other policies may set the second result argument accordingly.
template <class Ret>
struct UnsupportedTypeDefaultPolicy {
- void operator()(int32_t dt_type, Ret& /*result*/) const {
+ [[noreturn]] void operator()(int32_t dt_type, Ret& /*result*/) const {
ORT_THROW("Unsupported data type: ", dt_type);
}
};
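
For reviewers unfamiliar with C4702: once the unsupported-type policy is `[[noreturn]]`, MSVC can prove that statements following a dispatch that ends in the throwing policy are unreachable and may emit warning C4702, which is typically fatal when warnings are treated as errors; the push/disable/pop above scopes the suppression to `Invoke`. A minimal standalone sketch of the same pattern, assuming nothing about the ORT internals (the `ThrowPolicy`/`InvokeOnce` names are illustrative only, not ONNX Runtime code):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>

// Illustrative stand-in for an unsupported-type policy: it always throws.
struct ThrowPolicy {
  [[noreturn]] void operator()(int32_t dt_type) const {
    throw std::runtime_error("Unsupported data type: " + std::to_string(dt_type));
  }
};

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4702)  // 'return 0' below is unreachable when Fn never returns
#endif
// Returns an int so the call can sit in a pack-expansion context,
// mirroring the shape of CallableDispatchableHelper::Invoke.
template <typename Fn>
int InvokeOnce(Fn&& fn, int32_t dt_type) {
  std::forward<Fn>(fn)(dt_type);
  return 0;  // kept for the expandable-context trick; triggers C4702 with a [[noreturn]] Fn
}
#if defined(_MSC_VER)
#pragma warning(pop)
#endif

int main() {
  try {
    InvokeOnce(ThrowPolicy{}, 42);
  } catch (const std::exception&) {
    // expected: unsupported data type
  }
  return 0;
}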
diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index d1c280d9886f4..685937049e58f 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -31,7 +31,7 @@ namespace op_kernel_type_control {
// we're using one set of types for all opsets of Cast
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
- element_type_lists::AllIRv9);
+ element_type_lists::AllIRv10);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
@@ -39,7 +39,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0,
- element_type_lists::AllIRv9);
+ element_type_lists::AllIRv10);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0,
@@ -58,8 +58,43 @@ using IsOrtFloat16Type = boost::mp11::mp_contains,
#if !defined(DISABLE_FLOAT8_TYPES)
template <typename T>
using IsOrtFloat8Type = boost::mp11::mp_contains<element_type_lists::AllFloat8, T>;
+#else
+template <typename T>
+struct IsOrtFloat8Type : std::false_type {};
#endif
+template <typename T>
+using IsOrtInt4Type = boost::mp11::mp_contains<TypeList<Int4x2, UInt4x2>, T>;
+
+template <typename T>
+struct IsStandardIntegerType {
+  static constexpr bool value =
+      std::is_same_v<T, int8_t> ||
+      std::is_same_v<T, int16_t> ||
+      std::is_same_v<T, int32_t> ||
+      std::is_same_v<T, int64_t> ||
+      std::is_same_v<T, uint8_t> ||
+      std::is_same_v<T, uint16_t> ||
+      std::is_same_v<T, uint32_t> ||
+      std::is_same_v<T, uint64_t>;
+};
+
+// Types that Int4x2 and UInt4x2 convert to and from, apart from string.
+template <typename T>
+struct IsOrtInt4NumericConversionType {
+  static constexpr bool value =
+      std::is_same_v<T, bool> ||
+      IsStandardIntegerType<T>::value ||
+      std::is_floating_point_v<T> ||
+      IsOrtFloat16Type<T>::value ||
+      IsOrtFloat8Type<T>::value;
+};
+
+template <typename T>
+struct IsOrtInt4ConversionType {
+  static constexpr bool value = IsOrtInt4NumericConversionType<T>::value || std::is_same_v<T, std::string>;
+};
+
// string cast helpers
// Note: when C++17 is available, use <charconv> functions
@@ -115,11 +150,7 @@ CastToString(const SrcType& input, std::string& output) {
}
template <typename SrcType>
-#if !defined(DISABLE_FLOAT8_TYPES)
typename std::enable_if<IsOrtFloat16Type<SrcType>::value || IsOrtFloat8Type<SrcType>::value, void>::type
-#else
-typename std::enable_if<IsOrtFloat16Type<SrcType>::value>::type
-#endif
CastToString(const SrcType& input, std::string& output) {
  CastToString(static_cast<float>(input), output);
}
@@ -149,11 +180,7 @@ CastFromString(const std::string& input, DstType& output) {
}
template <typename DstType>
-#if !defined(DISABLE_FLOAT8_TYPES)
typename std::enable_if<IsOrtFloat16Type<DstType>::value || IsOrtFloat8Type<DstType>::value, void>::type
-#else
-typename std::enable_if<IsOrtFloat16Type<DstType>::value, void>::type
-#endif
CastFromString(const std::string& input, DstType& output) {
float intermediate;
CastFromString(input, intermediate);
@@ -177,6 +204,134 @@ template <>
struct EigenCastType<BFloat16> {
using type = Eigen::bfloat16;
};
+
+// Helper for converting (U)Int4x2 values to any destination type.
+template <typename SrcType, typename DstType, typename Enable = std::enable_if_t<IsOrtInt4Type<SrcType>::value && IsOrtInt4ConversionType<DstType>::value>>
+struct FromInt4Converter {
+ // The input 'val' can be either an int8_t value coming from Int4x2.GetElem(pos),
+ // or an uint8_t value coming from UInt4x2.GetElem(pos), where pos can be 0 or 1.
+ static DstType Convert(typename SrcType::UnpackedType val) {
+    if constexpr (IsOrtFloat16Type<DstType>::value) {
+      return DstType(static_cast<float>(val));
+    } else if constexpr (IsOrtFloat8Type<DstType>::value) {
+      return DstType(static_cast<float>(val), true);
+    } else if constexpr (std::is_same_v<DstType, bool>) {
+      return val != 0;
+    } else if constexpr (std::is_same_v<DstType, std::string>) {
+      // val has type (u)int8_t, so static_cast is required in order for std::to_string
+      // to interpret val as a number (65 -> "65"), instead of a char (65 -> "A").
+      return std::to_string(static_cast<int>(val));
+    } else {
+      return static_cast<DstType>(val);
+ }
+ }
+};
+
+// Helper for converting any source type to (U)Int4x2::UnpackedType values (int8_t and uint8_t).
+template <typename SrcType, typename DstType, typename Enable = std::enable_if_t<IsOrtInt4ConversionType<SrcType>::value && IsOrtInt4Type<DstType>::value>>
+struct ToInt4Converter {
+ static typename DstType::UnpackedType Convert(const SrcType& val);
+};
+
+// See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from
+// fixed point to fixed point: when OOR, discard higher bits and reinterpret
+// (with respect to two's complement representation for signed types).
+// The following example is listed: 200 (int16) converts to -56 (int8).
+// For our int4 conversion, 200 (int16) would convert to -8 (int4).
+template <typename SrcType>
+struct ToInt4Converter<SrcType, Int4x2, std::enable_if_t<IsStandardIntegerType<SrcType>::value>> {
+ static int8_t Convert(const SrcType& val) {
+ // Example: int8_t(14) converts to int4 (-2)
+ // int8_t(14) is 0000_1110
+ // truncate: 0000_1110 & 0000_1111 = 0000_1110
+ // in 4-bit two's complement, 1110 = 1 * -8 + 1 * 4 + 1 * 2 + 1 * 0 = -2
+ // sign-extend: -2 in int8_t is 1111_0000 | 0000_1110 = 1111_1110
+
+ // Truncate to 4 least significant bits
+    uint8_t truncated = static_cast<uint8_t>(val) & 0x0F;
+
+    // Sign-extend: if bit 3 is set, it's negative in 4-bit two's complement,
+    // so set the 4 most significant bits to 1.
+    return static_cast<int8_t>((truncated & 0x8) ? (truncated | 0xF0) : truncated);
+ }
+};
+
+// See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from
+// fixed point to fixed point: when OOR, discard higher bits and reinterpret
+// (with respect to two's complement representation for signed types).
+template <typename SrcType>
+struct ToInt4Converter<SrcType, UInt4x2, std::enable_if_t<IsStandardIntegerType<SrcType>::value>> {
+  static uint8_t Convert(const SrcType& val) {
+    // Truncate to 4 least significant bits
+    return static_cast<uint8_t>(val) & 0x0F;
+ }
+};
+
+// bool -> (U)Int4x2
+template <typename DstType>
+struct ToInt4Converter<bool, DstType, std::enable_if_t<IsOrtInt4Type<DstType>::value>> {
+  static typename DstType::UnpackedType Convert(bool val) {
+    return static_cast<typename DstType::UnpackedType>(val ? 1 : 0);
+ }
+};
+
+// float -> (U)Int4x2
+// Per https://onnx.ai/onnx/operators/onnx__Cast.html#summary, casting from
+// floating point to fixed point is undefined if OOR.
+template <typename DstType>
+struct ToInt4Converter<float, DstType, std::enable_if_t<IsOrtInt4Type<DstType>::value>> {
+  static typename DstType::UnpackedType Convert(const float& val) {
+    int result = static_cast<int>(std::roundf(val));
+    return ToInt4Converter<int, DstType>::Convert(result);
+ }
+};
+
+// double -> (U)Int4x2
+template <typename DstType>
+struct ToInt4Converter<double, DstType, std::enable_if_t<IsOrtInt4Type<DstType>::value>> {
+  static typename DstType::UnpackedType Convert(const double& val) {
+    int result = static_cast<int>(std::round(val));
+    return ToInt4Converter<int, DstType>::Convert(result);
+ }
+};
+
+// float 8 -> (U)Int4x2
+template <typename SrcType, typename DstType>
+struct ToInt4Converter<SrcType, DstType, std::enable_if_t<IsOrtFloat8Type<SrcType>::value && IsOrtInt4Type<DstType>::value>> {
+  static typename DstType::UnpackedType Convert(const SrcType& val) {
+    float result = val.ToFloat();
+    return ToInt4Converter<float, DstType>::Convert(result);
+ }
+};
+
+// float 16 -> (U)Int4x2
+template <typename SrcType, typename DstType>
+struct ToInt4Converter<SrcType, DstType, std::enable_if_t<IsOrtFloat16Type<SrcType>::value && IsOrtInt4Type<DstType>::value>> {
+  static typename DstType::UnpackedType Convert(const SrcType& val) {
+    float f_val = static_cast<float>(val);
+    return ToInt4Converter<float, DstType>::Convert(f_val);
+ }
+};
+
+// string -> (U)Int4x2
+template <typename DstType>
+struct ToInt4Converter<std::string, DstType, std::enable_if_t<IsOrtInt4Type<DstType>::value>> {
+  static typename DstType::UnpackedType Convert(const std::string& val) {
+    double result = std::stod(val);
+    return ToInt4Converter<double, DstType>::Convert(result);
+ }
+};
+
// generic tensor X -> Y
template <typename SrcType, typename DstType, typename Enable = void>
struct TensorCaster {
@@ -193,9 +348,10 @@ struct TensorCaster {
}
};
-// tensor X -> string
+// tensor X -> string, if X != (U)Int4x2
template <typename SrcType>
-struct TensorCaster<SrcType, std::string> {
+struct TensorCaster<SrcType, std::string, std::enable_if_t<!IsOrtInt4Type<SrcType>::value>> {
  void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
    const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
    const auto* in_data = in.Data<SrcType>();
@@ -206,9 +362,10 @@ struct TensorCaster {
}
};
-// tensor string -> X
+// tensor string -> X, if X != (U)Int4x2
template
-struct TensorCaster {
+struct TensorCaster::value>> {
void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
const std::ptrdiff_t shape_size = narrow(shape.Size());
const auto* in_data = in.Data();
@@ -219,46 +376,105 @@ struct TensorCaster {
}
};
-#if !defined(DISABLE_FLOAT8_TYPES)
+// tensor MLFloat16 -> float
+template <>
+struct TensorCaster<MLFloat16, float> {
+  void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    auto out_data = out.MutableData<float>();
+    auto in_data = in.Data<MLFloat16>();
+    const size_t shape_size = narrow<size_t>(shape.Size());
+ MlasConvertHalfToFloatBufferInParallel(in_data, out_data, shape_size, ctx.GetOperatorThreadPool());
+ }
+};
-// tensor X -> float 8
-template
-struct TensorCasterNoSat {
+// tensor float -> MLFloat16
+template <>
+struct TensorCaster {
void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
- const std::ptrdiff_t shape_size = narrow(shape.Size());
+ auto in_data = in.Data();
+ auto out_data = out.MutableData();
+ const size_t shape_size = narrow(shape.Size());
+ MlasConvertFloatToHalfBuffer(in_data, out_data, shape_size);
+ }
+};
+
+// (U)Int4x2 -> string or numeric types
+template <typename SrcType, typename DstType>
+struct TensorCaster<SrcType, DstType, std::enable_if_t<IsOrtInt4Type<SrcType>::value && IsOrtInt4ConversionType<DstType>::value>> {
+  void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    const ptrdiff_t shape_size = narrow<ptrdiff_t>(shape.Size());
    const auto* in_data = in.Data<SrcType>();
    auto* out_data = out.MutableData<DstType>();
-    for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
-      out_data[i] = DstType(static_cast<float>(in_data[i]), false);
+
+ for (ptrdiff_t i = 0; i < shape_size; ++i) {
+ // elem 0 is the low nibble, 1 the high nibble
+ auto val = in_data[i >> 1].GetElem(i & 0x1);
+      out_data[i] = FromInt4Converter<SrcType, DstType>::Convert(val);
}
}
};
-// tensor string -> float 8
-template <typename DstType>
-struct TensorCasterNoSat<std::string, DstType> {
+// string or numeric types -> (U)Int4x2
+template <typename SrcType, typename DstType>
+struct TensorCaster<SrcType, DstType, std::enable_if_t<IsOrtInt4ConversionType<SrcType>::value && IsOrtInt4Type<DstType>::value>> {
void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
-    const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
-    const auto* in_data = in.Data<std::string>();
+    const ptrdiff_t shape_size = narrow<ptrdiff_t>(shape.Size());
+    const auto* in_data = in.Data<SrcType>();
    auto* out_data = out.MutableData<DstType>();
- float float_value;
- for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
- CastFromString(in_data[i], float_value);
- out_data[i] = DstType(float_value, false);
+
+ ptrdiff_t i = 0;
+ for (; i < shape_size - 1; i += 2) {
+      auto low_val = ToInt4Converter<SrcType, DstType>::Convert(in_data[i]);
+      auto high_val = ToInt4Converter<SrcType, DstType>::Convert(in_data[i + 1]);
+ out_data[i >> 1] = DstType(low_val, high_val);
+ }
+
+ if (i < shape_size) {
+      auto low_val = ToInt4Converter<SrcType, DstType>::Convert(in_data[i]);
+ out_data[i >> 1] = DstType(low_val, 0);
}
}
};
-#endif
+// Int4x2 -> UInt4x2
+template <>
+struct TensorCaster<Int4x2, UInt4x2> {
+  void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    const ptrdiff_t shape_size = narrow<ptrdiff_t>(shape.Size() + 1) >> 1;
+    const auto* in_data = in.Data<Int4x2>();
+    auto* out_data = out.MutableData<UInt4x2>();
-// tensor MLFloat16 -> float
+ for (ptrdiff_t i = 0; i < shape_size; ++i) {
+ auto low_nibble = in_data[i].GetElem(0);
+ auto high_nibble = in_data[i].GetElem(1);
+
+      uint8_t low_unsigned = static_cast<uint8_t>(low_nibble) & 0x0F;
+      uint8_t high_unsigned = static_cast<uint8_t>(high_nibble) & 0x0F;
+
+ out_data[i] = UInt4x2(low_unsigned, high_unsigned);
+ }
+ }
+};
+
+// UInt4x2 -> Int4x2
template <>
-struct TensorCaster<MLFloat16, float> {
-  void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const {
-    auto out_data = out.MutableData<float>();
-    auto in_data = in.Data<MLFloat16>();
-    const size_t shape_size = narrow<size_t>(shape.Size());
- MlasConvertHalfToFloatBufferInParallel(in_data, out_data, shape_size, ctx.GetOperatorThreadPool());
+struct TensorCaster<UInt4x2, Int4x2> {
+  void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    const ptrdiff_t shape_size = narrow<ptrdiff_t>(shape.Size() + 1) >> 1;
+    const auto* in_data = in.Data<UInt4x2>();
+    auto* out_data = out.MutableData<Int4x2>();
+
+ for (ptrdiff_t i = 0; i < shape_size; ++i) {
+ auto low_nibble = in_data[i].GetElem(0);
+ auto high_nibble = in_data[i].GetElem(1);
+
+      int8_t low_signed = static_cast<int8_t>((low_nibble & 0x0F) << 4) >> 4;
+      int8_t high_signed = static_cast<int8_t>((high_nibble & 0x0F) << 4) >> 4;
+
+ out_data[i] = Int4x2(low_signed, high_signed);
+ }
}
};
@@ -284,7 +500,8 @@ void CastMLFloat16ThroughFloatTensor(
// tensor MLFloat16 -> X
template <typename DstType>
-struct TensorCaster<MLFloat16, DstType> {
+struct TensorCaster<MLFloat16, DstType, std::enable_if_t<!IsOrtInt4Type<DstType>::value>> {
  void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& in, Tensor& out) const {
    CastMLFloat16ThroughFloatTensor<DstType>(context, shape, in, out);
}
@@ -299,6 +516,54 @@ struct TensorCaster {
};
#endif
+#if !defined(DISABLE_FLOAT8_TYPES)
+// TensorCasterNoSat is only called when all the below conditions are met (see Cast::Compute):
+// - defined(DISABLE_FLOAT8_TYPES) == false
+// - saturate_ == false
+// - IsOrtFloat8Type<DstType>::value == true
+
+// tensor X -> float 8
+template <typename SrcType, typename DstType, typename Enable = std::enable_if_t<!IsOrtInt4Type<SrcType>::value>>
+struct TensorCasterNoSat {
+ void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
+    const auto* in_data = in.Data<SrcType>();
+    auto* out_data = out.MutableData<DstType>();
+    for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
+      out_data[i] = DstType(static_cast<float>(in_data[i]), false);
+ }
+ }
+};
+
+// tensor (U)Int4x2 -> float 8
+template <typename SrcType, typename DstType>
+struct TensorCasterNoSat<SrcType, DstType, std::enable_if_t<IsOrtInt4Type<SrcType>::value && IsOrtFloat8Type<DstType>::value>> {
+  void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& src, Tensor& dst) const {
+    // Int4x2/UInt4x2 always fit inside any Float8 type, so we can reuse the saturate = true implementation.
+    TensorCaster<SrcType, DstType>{}.Cast(context, shape, src, dst);
+ }
+};
+
+// tensor string -> float 8
+template <typename DstType>
+struct TensorCasterNoSat<std::string, DstType, std::enable_if_t<IsOrtFloat8Type<DstType>::value>> {
+  void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
+    const auto* in_data = in.Data<std::string>();
+    auto* out_data = out.MutableData<DstType>();
+ float float_value;
+ for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
+ CastFromString(in_data[i], float_value);
+ out_data[i] = DstType(float_value, false);
+ }
+ }
+};
+
+#endif
+
class Cast final : public OpKernel {
public:
Cast(const OpKernelInfo& info) : OpKernel(info) {
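
As a cross-check of the conversion rules implemented above, here is a small self-contained sketch (plain C++, no ONNX Runtime types; `ToSignedInt4`, `PackInt4Pair`, and `UnpackInt4` are hypothetical helpers written for this note, not part of the patch). It mirrors the truncate-and-sign-extend rule of `ToInt4Converter`, the shift-based sign extension used in the `UInt4x2` -> `Int4x2` caster, and the two-elements-per-byte layout with element 0 in the low nibble:

#include <cassert>
#include <cstdint>

// Truncate to 4 bits and sign-extend, mirroring the integer -> Int4x2 converter above.
inline int8_t ToSignedInt4(int value) {
  uint8_t truncated = static_cast<uint8_t>(value) & 0x0F;
  return static_cast<int8_t>((truncated & 0x8) ? (truncated | 0xF0) : truncated);
}

// Pack two 4-bit values into one byte: element 0 in the low nibble, element 1 in the high nibble.
inline uint8_t PackInt4Pair(int8_t low, int8_t high) {
  return static_cast<uint8_t>((low & 0x0F) | ((high & 0x0F) << 4));
}

// Unpack element 'pos' (0 = low nibble, 1 = high nibble) as a signed 4-bit value,
// using the same shift-based sign extension as the UInt4x2 -> Int4x2 caster.
inline int8_t UnpackInt4(uint8_t packed, int pos) {
  uint8_t nibble = (pos == 0) ? (packed & 0x0F) : (packed >> 4);
  return static_cast<int8_t>(static_cast<int8_t>(nibble << 4) >> 4);
}

int main() {
  // Out-of-range truncation examples from the comments in the patch.
  assert(ToSignedInt4(200) == -8);  // 200 & 0x0F = 8, which sign-extends to -8
  assert(ToSignedInt4(14) == -2);   // 0b1110 is -2 in 4-bit two's complement

  // Round-trip a pair of elements through one packed byte.
  uint8_t packed = PackInt4Pair(ToSignedInt4(-3), ToSignedInt4(7));
  assert(UnpackInt4(packed, 0) == -3);
  assert(UnpackInt4(packed, 1) == 7);
  return 0;
}

Compiled as-is with any C++11 compiler, the asserts reproduce the worked examples in the patch comments.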
diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc
index a74ecacc1f26e..647c947f37f0c 100644
--- a/onnxruntime/test/onnx/TestCase.cc
+++ b/onnxruntime/test/onnx/TestCase.cc
@@ -948,6 +948,16 @@ std::unique_ptr> GetBrokenTests(const std::string& provider
{"slice_neg_steps",
"Type parameter (Tind) bound to different types (tensor(int64) and tensor(int32) in node ()."},
{"cast_BFLOAT16_to_FLOAT", "Unexpected input data type"},
+ {"cast_FLOAT16_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_FLOAT16_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_FLOAT_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_FLOAT_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_INT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_INT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_INT4_to_INT8", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_UINT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_UINT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
+ {"cast_UINT4_to_UINT8", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"},
{"loop13_seq", "Creation of empty sequences is currently not supported in the test runner"},
{"sequence_insert_at_front", "shape mismatch, expect {4} got {3}"},
{"cast_FLOAT_to_BFLOAT16", "expect uint16 got bfloat16"},
diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
index 384adb5916cc1..68d4f3559504a 100644
--- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include
@@ -57,7 +57,7 @@ void TestCastOp(gsl::span input,
const BaseTester::DimsVariant& dimensions,
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
const std::string& expected_failure_string = "",
- int opset = 13,
+ int opset = 21,
Saturate saturate = Saturate::None) {
OpTester test("Cast", opset);
  test.AddAttribute<int64_t>("to", utils::ToTensorProtoElementType<DstType>());
@@ -207,6 +207,1036 @@ TEST(CastOpTest, ToString) {
TestCastOp(gsl::make_span(int_16_input), gsl::make_span(int_string_data), shape);
}
+TEST(CastOpTest, Int4x2ToInt8) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, -5), // positive and negative
+ Int4x2(6, 2) // both positive
+ };
+
+  const std::vector<int8_t> expected_int8_output = {-8, 7, 0, -1, 3, -5, 6, 2};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int8_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToUInt8) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, -5), // positive and negative
+ Int4x2(6, 2) // both positive
+ };
+
+ // Negative values will be cast to their unsigned representation
+  const std::vector<uint8_t> expected_uint8_output = {248, 7, 0, UINT8_MAX, 3, 251, 6, 2};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint8_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToInt16) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, -5), // positive and negative
+ Int4x2(6, 2) // both positive
+ };
+
+  const std::vector<int16_t> expected_int16_output = {-8, 7, 0, -1, 3, -5, 6, 2};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int16_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToUInt16) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, -5), // positive and negative
+ Int4x2(6, 2) // both positive
+ };
+
+ // Negative values will be cast to their unsigned representation
+  const std::vector<uint16_t> expected_uint16_output = {65528, 7, 0, UINT16_MAX, 3, 65531, 6, 2};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint16_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToInt32) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, -5), // positive and negative
+ Int4x2(6, 2) // both positive
+ };
+
+  const std::vector<int32_t> expected_int32_output = {-8, 7, 0, -1, 3, -5, 6, 2};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int32_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToUInt32) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, -5), // positive and negative
+ Int4x2(6, 2) // both positive
+ };
+
+ // Negative values will be cast to their unsigned representation
+  const std::vector<uint32_t> expected_uint32_output = {4294967288, 7, 0, UINT32_MAX, 3, 4294967291, 6, 2};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint32_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToInt32OddNumberOfElements) {
+ // GIVEN
+  const std::vector<int64_t> odd_shape{5};
+  const std::vector<Int4x2> odd_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, 0),
+ };
+
+  const std::vector<int32_t> expected_odd_output = {-8, 7, 0, -1, 3};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(odd_input), gsl::make_span(expected_odd_output), odd_shape);
+}
+
+TEST(CastOpTest, Int4x2ToInt64) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, -5), // positive and negative
+ Int4x2(6, 2) // both positive
+ };
+
+  const std::vector<int64_t> expected_int64_output = {-8, 7, 0, -1, 3, -5, 6, 2};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int64_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToUInt64) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -1), // zero and negative
+ Int4x2(3, -5), // positive and negative
+ Int4x2(6, 2) // both positive
+ };
+
+ // Negative values will be cast to their unsigned representation
+  const std::vector<uint64_t> expected_uint64_output = {18446744073709551608ULL, 7, 0, UINT64_MAX, 3, 18446744073709551611ULL, 6, 2};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint64_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToUInt8) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<uint8_t> expected_uint8_output = {0, 15, 1, 14, 7, 8, 3, 12};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint8_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToInt8) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<int8_t> expected_int8_output = {0, 15, 1, 14, 7, 8, 3, 12};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int8_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToUInt16) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<uint16_t> expected_uint16_output = {0, 15, 1, 14, 7, 8, 3, 12};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint16_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToInt16) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<int16_t> expected_int16_output = {0, 15, 1, 14, 7, 8, 3, 12};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int16_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToUInt32) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<uint32_t> expected_uint32_output = {0, 15, 1, 14, 7, 8, 3, 12};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint32_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToInt32) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<int32_t> expected_int32_output = {0, 15, 1, 14, 7, 8, 3, 12};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int32_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToUInt64) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<uint64_t> expected_uint64_output = {0, 15, 1, 14, 7, 8, 3, 12};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint64_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToInt64) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<int64_t> expected_int64_output = {0, 15, 1, 14, 7, 8, 3, 12};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int64_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToBool) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(0, -1), // zero and non-zero
+ Int4x2(7, 0), // non-zero and zero
+ Int4x2(-8, 3), // both non-zero
+ Int4x2(0, 0) // both zero
+ };
+
+ const bool bool_output[] = {false, true, true, false, true, true, false, false};
+  const gsl::span<const bool> expected_bool_output_span(bool_output);
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), expected_bool_output_span, shape);
+}
+
+TEST(CastOpTest, UInt4x2ToBool) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 1), // zero and non-zero
+ UInt4x2(15, 0), // non-zero and zero
+ UInt4x2(8, 7), // both non-zero
+ UInt4x2(0, 0) // both zero
+ };
+
+ const bool bool_output[] = {false, true, true, false, true, true, false, false};
+  const gsl::span<const bool> expected_bool_output_span(bool_output);
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), expected_bool_output_span, shape);
+}
+
+TEST(CastOpTest, Int4x2ToFloat) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(1, 2), // two 4-bit int elements: lower = 1, upper = 2
+ Int4x2(-3, -4),
+ Int4x2(5, -6),
+ Int4x2(-8, 7)};
+
+  const std::vector<float> expected_float_output = {1.0f, 2.0f, -3.0f, -4.0f, 5.0f, -6.0f, -8.0f, 7.0f};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToFloat) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 1),
+ UInt4x2(2, 3),
+ UInt4x2(7, 8),
+ UInt4x2(14, 15)};
+
+  const std::vector<float> expected_float_output = {0.0f, 1.0f, 2.0f, 3.0f, 7.0f, 8.0f, 14.0f, 15.0f};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_float_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToDouble) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7), // boundary values
+ Int4x2(0, -3), // zero and negative
+ Int4x2(4, -2), // positive and negative
+ Int4x2(1, 6) // both positive
+ };
+
+  const std::vector<double> expected_double_output = {-8.0, 7.0, 0.0, -3.0, 4.0, -2.0, 1.0, 6.0};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_double_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToDouble) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15), // boundary values
+ UInt4x2(1, 14), // small and large
+ UInt4x2(7, 8), // middle values
+ UInt4x2(3, 12) // mixed values
+ };
+
+  const std::vector<double> expected_double_output = {0.0, 15.0, 1.0, 14.0, 7.0, 8.0, 3.0, 12.0};
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_double_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToMLFloat16) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7),
+ Int4x2(0, -1),
+ Int4x2(3, -5),
+ Int4x2(6, 2)};
+
+  const std::vector<MLFloat16> expected_float16_output =
+ CastedValues(
+ gsl::make_span(
+ std::vector{-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}));
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float16_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToMLFloat16) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15),
+ UInt4x2(1, 14),
+ UInt4x2(7, 8),
+ UInt4x2(3, 12)};
+
+  const std::vector<MLFloat16> expected_float16_output =
+ CastedValues(
+ gsl::make_span(
+ std::vector{0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}));
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_float16_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToBFloat16) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<Int4x2> int4x2_input = {
+ Int4x2(-8, 7),
+ Int4x2(0, -1),
+ Int4x2(3, -5),
+ Int4x2(6, 2)};
+
+  const std::vector<BFloat16> expected_bfloat16_output =
+ CastedValues(
+ gsl::make_span(
+ std::vector{-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}));
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_bfloat16_output), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToBFloat16) {
+ // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<UInt4x2> uint4x2_input = {
+ UInt4x2(0, 15),
+ UInt4x2(1, 14),
+ UInt4x2(7, 8),
+ UInt4x2(3, 12)};
+
+  const std::vector<BFloat16> expected_bfloat16_output =
+ CastedValues(
+ gsl::make_span(
+ std::vector{0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}));
+
+ // WHEN, THEN
+ TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_bfloat16_output), shape);
+}
+
+TEST(CastOpTest, Int4x2ToString) {
+ // GIVEN
+ const std::vector