diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 7596ab7592b25..d90280a5a3357 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -69,8 +69,8 @@ The **OpSet Version** column uses the following notation:
|BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BlackmanWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Cast|*in* input:**T1**
*out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
-|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**
*out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||23|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
diff --git a/include/onnxruntime/core/common/float8.h b/include/onnxruntime/core/common/float8.h
index 3afd1b0ebd767..9c342a346ee0b 100644
--- a/include/onnxruntime/core/common/float8.h
+++ b/include/onnxruntime/core/common/float8.h
@@ -701,7 +701,17 @@ struct Float8E8M0 {
static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
constexpr ORT_HOST_DEVICE Float8E8M0(unsigned char bits, FromBitsT) : val(bits) {}
- inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true) {
+ /// Rounding modes for Float8E8M0 conversion from float.
+ /// These correspond to the ONNX Cast op's round_mode attribute for float8e8m0.
+ /// See: https://github.com/onnx/onnx/blob/main/onnx/numpy_helper.py (to_float8e8m0)
+ enum class RoundMode : uint8_t {
+ Up, // Ceiling: always round up to next power of 2 when not exact (default).
+ Down, // Floor: always truncate to lower power of 2.
+ Nearest, // Round to nearest power of 2; ties round to higher power (round-half-up).
+ };
+
+ inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true,
+ RoundMode round_mode = RoundMode::Up) {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));
@@ -756,34 +766,74 @@ struct Float8E8M0 {
return;
}
- // Denormalized float32: value = 2^(-126) * (mantissa / 2^23)
- // The largest subnormal is ~2^(-126) * (1 - 2^-23), which should round to 2^(-126) = val 1.
- // The midpoint between 2^(-127) and 2^(-126) is 1.5 * 2^(-127).
- // Subnormals with value >= midpoint round up to 2^(-126) (val=1), others to 2^(-127) (val=0).
- // Midpoint in subnormal mantissa: 0x00600000 (mantissa >= 0.75 * 2^23 means value >= 1.5 * 2^-127).
+ // Denormalized float32: value = 2^(-126) * (mantissa / 2^23), range (0, 2^(-126)).
+ // E8M0 can represent 2^(-127) (val=0) and 2^(-126) (val=1). For nearest rounding,
+ // the midpoint is 1.5 * 2^(-127), which is mantissa 0x600000. Ties round up.
if (exponent == 0) {
- if (saturate) {
- if (mantissa >= 0x00600000) {
- val = 0x01; // Round up to 2^(-126)
- } else {
- val = 0x00; // Round down to 2^(-127)
- }
+ // Subnormals with mantissa < 0x400000 have value < E8M0_MIN (2^-127) and
+ // cannot be represented. Without saturation they map to NaN.
+ // Subnormals with mantissa >= 0x400000 have value >= E8M0_MIN, so they
+ // round to val=0 or val=1, both valid E8M0 values.
+ if (!saturate && mantissa < 0x00400000) {
+ val = 0xFF; // NaN: x < E8M0_MIN is not representable without saturation
+ return;
+ }
+ bool round_up;
+ switch (round_mode) {
+ case RoundMode::Up:
+ // Ceiling: round up only when value > 2^(-127). Denorm mantissa == 0x400000
+ // is exactly 2^(-127) (val=0), so it must NOT round up.
+ round_up = (mantissa > 0x00400000);
+ break;
+ case RoundMode::Down:
+ // Floor: always keep val=0 (2^(-127)), never increment.
+ round_up = false;
+ break;
+ case RoundMode::Nearest:
+ default:
+ round_up = (mantissa >= 0x00600000);
+ break;
+ }
+ if (round_up) {
+ val = 0x01; // 2^(-126)
} else {
- val = 0xFF; // NaN (subnormals are below E8M0 min for saturate=false)
+ val = 0x00; // 2^(-127)
}
return;
}
- // Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23)
- // We need to round to the nearest power of 2.
- // Round half up: round to next power of 2 when mantissa >= 0.5
- // (i.e., when the float value is >= 1.5 * nearest lower power of 2)
- // This aligns with the OCP Microscaling Formats (MX) spec for E8M0 scaling factors.
- if (mantissa >= 0x00400000) { // >= 0.5
+ // Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23).
+ // Values with exponent=254 and mantissa>0 lie in (2^127, 2^128), i.e. strictly above
+ // E8M0_MAX (2^127); 2^128 itself is not representable in E8M0 (val 255 = NaN). Without
+ // saturation such values are out of range and become NaN, regardless of round_mode.
+ if (!saturate && exponent == 0xFE && mantissa != 0) {
+ val = 0xFF; // NaN: x > E8M0_MAX is not representable without saturation
+ return;
+ }
+ // Pick the rounding direction using the ONNX round_mode semantics:
+ // Up (ceiling): round up when the float is not exactly a power of 2 (mantissa > 0).
+ // Down (floor): never round up; always keep the lower exponent.
+ // Nearest: G bit (bit 22) determines direction -- round up when mantissa >= 0x400000.
+ // For normal floats lsb of exponent is always considered 1, so ties
+ // round to the higher power of 2 (round-half-up).
+ bool round_up;
+ switch (round_mode) {
+ case RoundMode::Up:
+ round_up = (mantissa > 0);
+ break;
+ case RoundMode::Down:
+ round_up = false;
+ break;
+ case RoundMode::Nearest:
+ default:
+ round_up = (mantissa >= 0x00400000);
+ break;
+ }
+ if (round_up) {
exponent += 1;
}
- // After rounding, exponent may overflow
+ // After rounding, exponent may overflow.
if (exponent > 0xFE) {
if (saturate) {
val = 0xFE; // Largest finite: 2^127
diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index 38d5c95813be2..1d5a7c63228b3 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -29,18 +29,27 @@
namespace onnxruntime {
namespace {
-// Define a type list that extends AllIRv10 with INT2 types, but without Float4
+// Define a type list that extends AllIRv10 with INT2 types, but without Float4.
// Float4E2M1x2 doesn't support all the casting operations that other types do,
// so we don't include it here for the Cast operator.
-using AllIRv10WithInt2 =
+using AllIRv10WithInt2Base =
boost::mp11::mp_push_back<
element_type_lists::AllIRv10,
UInt2x4,
Int2x4>;
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+// Float8E8M0 was added in opset 24 (IR v12), so include it in the full type list
+// but use the base list (without Float8E8M0) for pre-opset-24 kernel registrations.
+using AllIRv10WithInt2 =
+ boost::mp11::mp_push_back;
+#else
+using AllIRv10WithInt2 = AllIRv10WithInt2Base;
+#endif
} // namespace
namespace op_kernel_type_control {
-// we're using one set of types for all opsets of Cast
+// Type list for all opsets of Cast (includes Float8E8M0 for runtime dispatch).
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
AllIRv10WithInt2);
@@ -64,6 +73,17 @@ using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecu
using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Cast, Output, 0);
+// Pre-opset-24 type lists (without Float8E8M0) for kernel registration TypeConstraints.
+// Float8E8M0 was introduced in opset 24.
+#if !defined(DISABLE_FLOAT8_TYPES)
+using EnabledSrcTypesPreOpset24 = boost::mp11::mp_remove;
+using EnabledDstTypesPreOpset24 = boost::mp11::mp_remove;
+#else
+// When float8 types are disabled, Float8E8M0 is not in the type lists, so no removal needed.
+using EnabledSrcTypesPreOpset24 = EnabledSrcTypes;
+using EnabledDstTypesPreOpset24 = EnabledDstTypes;
+#endif
+
template
using IsOrtFloat16Type = boost::mp11::mp_contains, T>;
@@ -958,7 +978,8 @@ class Cast final : public OpKernel {
if (saturate == 0 && (to != ONNX_NAMESPACE::TensorProto::FLOAT8E4M3FN &&
to != ONNX_NAMESPACE::TensorProto::FLOAT8E4M3FNUZ &&
to != ONNX_NAMESPACE::TensorProto::FLOAT8E5M2 &&
- to != ONNX_NAMESPACE::TensorProto::FLOAT8E5M2FNUZ)) {
+ to != ONNX_NAMESPACE::TensorProto::FLOAT8E5M2FNUZ &&
+ to != ONNX_NAMESPACE::TensorProto::FLOAT8E8M0)) {
ORT_THROW("Attribute saturate is only used for cast to float 8 types.");
}
#else
@@ -967,6 +988,27 @@ class Cast final : public OpKernel {
}
#endif
saturate_ = saturate == 1;
+
+ // round_mode only applies for casting to float8e8m0 (introduced in opset 24)
+ std::string round_mode_str = info.GetAttrOrDefault("round_mode", std::string("up"));
+#if !defined(DISABLE_FLOAT8_TYPES)
+ if (round_mode_str == "up") {
+ round_mode_ = Float8E8M0::RoundMode::Up;
+ } else if (round_mode_str == "down") {
+ round_mode_ = Float8E8M0::RoundMode::Down;
+ } else if (round_mode_str == "nearest") {
+ round_mode_ = Float8E8M0::RoundMode::Nearest;
+ } else {
+ ORT_THROW("Attribute round_mode must be 'up', 'down', or 'nearest'.");
+ }
+ if (round_mode_ != Float8E8M0::RoundMode::Up && to != ONNX_NAMESPACE::TensorProto::FLOAT8E8M0) {
+ ORT_THROW("Attribute round_mode is only used for cast to float8e8m0.");
+ }
+#else
+ if (round_mode_str != "up") {
+ ORT_THROW("Attribute round_mode is only used for cast to float8e8m0.");
+ }
+#endif
}
Status Compute(OpKernelContext* context) const override;
@@ -974,6 +1016,9 @@ class Cast final : public OpKernel {
private:
ONNX_NAMESPACE::TensorProto_DataType to_;
bool saturate_;
+#if !defined(DISABLE_FLOAT8_TYPES)
+ Float8E8M0::RoundMode round_mode_{Float8E8M0::RoundMode::Up};
+#endif
};
template
@@ -1020,6 +1065,43 @@ struct SrcDispatcherNoSat {
}
};
+// Dispatch target for casting any source type to Float8E8M0: each output element is constructed as
+// Float8E8M0(value, saturate, round_mode). Bypasses TensorCaster/TensorCasterNoSat so round_mode can be threaded through.
+template
+struct CastToE8M0Dispatcher {
+ void operator()(const OpKernelContext&, const TensorShape& shape, const Tensor& src, Tensor& dst,
+ bool saturate, Float8E8M0::RoundMode round_mode) {
+ const auto shape_size = narrow(shape.Size());  // number of elements iterated below
+ auto* out_data = dst.MutableData();
+
+ if constexpr (IsOrtInt4Type::value) {  // packed source: two elements per storage item
+ const auto* in_data = src.Data();
+ for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
+ auto val = in_data[i >> 1].GetElem(i & 0x1);  // element i is slot (i % 2) of item (i / 2)
+ out_data[i] = Float8E8M0(static_cast(val), saturate, round_mode);
+ }
+ } else if constexpr (IsOrtInt2Type::value) {  // packed source: four elements per storage item
+ const auto* in_data = src.Data();
+ for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
+ auto val = in_data[i >> 2].GetElem(i & 0x3);  // element i is slot (i % 4) of item (i / 4)
+ out_data[i] = Float8E8M0(static_cast(val), saturate, round_mode);
+ }
+ } else if constexpr (std::is_same_v) {  // source converted through CastFromString
+ const auto* in_data = src.Data();
+ for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
+ float float_val;
+ CastFromString(in_data[i], float_val);  // fills float_val from the source element
+ out_data[i] = Float8E8M0(float_val, saturate, round_mode);
+ }
+ } else {  // every other source type: plain element-wise static_cast
+ const auto* in_data = src.Data();
+ for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
+ out_data[i] = Float8E8M0(static_cast(in_data[i]), saturate, round_mode);
+ }
+ }
+ }
+};
+
#endif
Status Cast::Compute(OpKernelContext* context) const {
@@ -1040,6 +1122,14 @@ Status Cast::Compute(OpKernelContext* context) const {
}
#if !defined(DISABLE_FLOAT8_TYPES)
+ // Float8E8M0 destination needs special handling for round_mode support.
+ // Dispatch directly to avoid threading round_mode through the TensorCaster templates.
+ if (to_ == ONNX_NAMESPACE::TensorProto::FLOAT8E8M0) {
+ utils::MLTypeCallDispatcherFromTypeList dispatcher{from};
+ dispatcher.Invoke(*context, shape, *X, *Y, saturate_, round_mode_);
+ return Status::OK();
+ }
+
if (saturate_) {
#endif
utils::MLTypeCallDispatcherFromTypeList dispatcher{from};
@@ -1063,8 +1153,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
6,
12,
KernelDefBuilder()
- .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
- .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);
@@ -1073,8 +1163,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
13,
18,
KernelDefBuilder()
- .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
- .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);
@@ -1083,8 +1173,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
19,
20,
KernelDefBuilder()
- .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
- .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);
@@ -1094,8 +1184,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
21,
22,
KernelDefBuilder()
- .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
- .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);
@@ -1106,11 +1196,12 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
23,
23,
KernelDefBuilder()
- .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
- .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList())
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);
+// Opset 24 added support for float8e8m0.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Cast,
24,
diff --git a/onnxruntime/test/framework/float8e8m0_test.cc b/onnxruntime/test/framework/float8e8m0_test.cc
index e27694355c3e9..d02d6a16d8d40 100644
--- a/onnxruntime/test/framework/float8e8m0_test.cc
+++ b/onnxruntime/test/framework/float8e8m0_test.cc
@@ -121,9 +121,9 @@ TEST(Float8E8M0_Tests, Rounding) {
Float8E8M0 val_1_5(1.5f);
EXPECT_EQ(val_1_5.val, 128); // 2^1 = 2.0
- // 1.25 should round down to 1.0 (mantissa < 0.5)
+ // 1.25 rounds up to 2.0 with default "up" (ceiling) mode since mantissa != 0
Float8E8M0 val_1_25(1.25f);
- EXPECT_EQ(val_1_25.val, 127); // 2^0 = 1.0
+ EXPECT_EQ(val_1_25.val, 128); // 2^1 = 2.0
// 3.0 should round up to 4.0 (mantissa = 0.5)
Float8E8M0 val_3(3.0f);
@@ -257,9 +257,25 @@ TEST(Float8E8M0_Tests, SubnormalRounding) {
Float8E8M0 from_small_subnorm(small_subnorm, true);
EXPECT_EQ(from_small_subnorm.val, 0x00); // Rounds down to 2^(-127)
- // With saturate=false, all subnormals produce NaN
+ // In nearest mode, subnormals below the midpoint between 2^-127 and 2^-126
+ // round down to 2^-127. Up mode would round this value to 2^-126.
+ uint32_t below_midpoint_bits = 0x00500000;
+ float below_midpoint;
+ std::memcpy(&below_midpoint, &below_midpoint_bits, sizeof(float));
+ Float8E8M0 nearest_below_midpoint(below_midpoint, true, Float8E8M0::RoundMode::Nearest);
+ EXPECT_EQ(nearest_below_midpoint.val, 0x00);
+
+ // The exact midpoint ties upward.
+ uint32_t midpoint_bits = 0x00600000;
+ float midpoint;
+ std::memcpy(&midpoint, &midpoint_bits, sizeof(float));
+ Float8E8M0 nearest_midpoint(midpoint, true, Float8E8M0::RoundMode::Nearest);
+ EXPECT_EQ(nearest_midpoint.val, 0x01);
+
+ // With saturate=false, subnormals within E8M0 range are still valid positive values,
+ // so they round normally (not NaN). Largest subnormal rounds up to 2^(-126).
Float8E8M0 subnorm_nosat(largest_subnorm, false);
- EXPECT_TRUE(subnorm_nosat.IsNaN());
+ EXPECT_EQ(subnorm_nosat.val, 0x01);
}
TEST(Float8E8M0_Tests, BatchConversionSpecialValues) {
diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
index 4481cf36554cd..0e14bc59a09c9 100644
--- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -3127,5 +3127,318 @@ TEST(CastOpTest, CopyCpuTensor_SubByteTypes_DistinctBuffers) {
}
}
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+float FloatFromBits(uint32_t bits) {
+ float value;
+ std::memcpy(&value, &bits, sizeof(float));
+ return value;
+}
+
+template
+void TestCastToFloat8E8M0(gsl::span input,
+ gsl::span output,
+ const std::vector& shape,
+ Saturate saturate = Saturate::None,  // None: the "saturate" attribute is not set on the node
+ const std::string& round_mode = "",  // empty: the "round_mode" attribute is not set on the node
+ int opset = 24,
+ OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
+ const std::string& expected_failure_string = "") {
+ OpTester test("Cast", opset);
+ test.AddAttribute("to", utils::ToTensorProtoElementType());  // "to" = destination element type of the Cast
+ test.AddInput("input", shape, input.data(), input.size());
+ test.AddOutput("output", shape, output.data(), output.size());
+ if (saturate != Saturate::None) {  // only emit the attribute when the caller asked for a specific value
+ test.AddAttribute("saturate", saturate == Saturate::True ? 1 : 0);
+ }
+ if (!round_mode.empty()) {  // likewise, leave round_mode unset unless explicitly requested
+ test.AddAttribute("round_mode", round_mode);
+ }
+
+ std::vector> execution_providers;
+ execution_providers.emplace_back(DefaultCpuExecutionProvider());  // the only provider configured for this run
+ test.ConfigEps(std::move(execution_providers))
+ .Config(expect_result, expected_failure_string)
+ .RunWithConfig();
+}
+
+template
+void TestCastFromFloat8E8M0(gsl::span input,
+ gsl::span output,
+ const std::vector& shape,
+ int opset = 24) {
+ OpTester test("Cast", opset);
+ test.AddAttribute("to", utils::ToTensorProtoElementType());
+ test.AddInput("input", shape, input.data(), input.size());
+ test.AddOutput("output", shape, output.data(), output.size());
+
+ std::vector> execution_providers;
+ execution_providers.emplace_back(DefaultCpuExecutionProvider());
+ test.ConfigEps(std::move(execution_providers))
+ .RunWithConfig();
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_Saturate) {
+ const std::vector shape{8};
+ // Test values: NaN, -1, 4.0 (exact power of 2), 1.5 (not a power of 2), -Inf, +Inf, 0, 1e-39 (float32 subnormal)
+ const std::vector input = {NAN, -1.0f, 4.0f, 1.5f,
+ -std::numeric_limits::infinity(),
+ std::numeric_limits::infinity(),
+ 0.0f, 1e-39f};
+
+ std::vector expected;
+ expected.reserve(input.size());
+ for (float v : input) {
+ expected.emplace_back(Float8E8M0(v, true));  // reference value: saturate=true, default round mode (Up)
+ }
+
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True);
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_NoSaturate) {
+ const std::vector shape{6};
+ const std::vector input = {NAN, -1.0f, 4.0f, 0.0f,
+ -std::numeric_limits::infinity(),
+ std::numeric_limits::infinity()};
+
+ std::vector expected;
+ expected.reserve(input.size());
+ for (float v : input) {
+ expected.emplace_back(Float8E8M0(v, false));  // reference value: saturate=false, default round mode (Up)
+ }
+
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::False);
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_RoundModeUp) {
+ const std::vector shape{4};
+ // "up" mode is ceiling: always round up to the next power of 2 when not exact.
+ // Exact powers of 2 (mantissa == 0) are unchanged; all others round up.
+ // 1.5 (mantissa != 0) -> 2^1 = 2.0 (val=128)
+ // 3.0 (mantissa != 0) -> 2^2 = 4.0 (val=129)
+ // 1.3 (mantissa != 0) -> 2^1 = 2.0 (val=128) [ceiling, not round-half-up]
+ // 2.5 (mantissa != 0) -> 2^2 = 4.0 (val=129) [ceiling, not round-half-up]
+ const std::vector input = {1.5f, 3.0f, 1.3f, 2.5f};
+ const std::vector expected = {
+ Float8E8M0(128, Float8E8M0::FromBits()), // 1.5 -> 2.0
+ Float8E8M0(129, Float8E8M0::FromBits()), // 3.0 -> 4.0
+ Float8E8M0(128, Float8E8M0::FromBits()), // 1.3 -> 2.0
+ Float8E8M0(129, Float8E8M0::FromBits()), // 2.5 -> 4.0
+ };
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "up");
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_RoundModeDown) {
+ const std::vector shape{4};
+ // "down" mode is floor: always truncate to the lower power of 2, never increment.
+ // All non-power-of-2 values keep the lower exponent regardless of their fractional part.
+ // 1.5 -> 2^0 = 1.0 (val=127) [floor, not round-half-down]
+ // 3.0 -> 2^1 = 2.0 (val=128) [floor]
+ // 1.7 -> 2^0 = 1.0 (val=127) [floor, not round-half-down -- 1.7 > midpoint but still floors]
+ // 2.5 -> 2^1 = 2.0 (val=128) [floor]
+ const std::vector input = {1.5f, 3.0f, 1.7f, 2.5f};
+ const std::vector expected = {
+ Float8E8M0(127, Float8E8M0::FromBits()), // 1.5 -> 1.0
+ Float8E8M0(128, Float8E8M0::FromBits()), // 3.0 -> 2.0
+ Float8E8M0(127, Float8E8M0::FromBits()), // 1.7 -> 1.0
+ Float8E8M0(128, Float8E8M0::FromBits()), // 2.5 -> 2.0
+ };
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "down");
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_RoundModeNearest) {
+ const std::vector shape{4};
+ // "nearest" mode: round to nearest power of 2; ties (exactly halfway) round up.
+ // Decision threshold: guard bit (bit 22 of mantissa, representing 0.5 of fractional part).
+ // 1.5 -> midpoint (mantissa=0x400000) -> round up -> 2^1 = 2.0 (val=128)
+ // 3.0 -> midpoint (mantissa=0x400000) -> round up -> 2^2 = 4.0 (val=129)
+ // 1.3 -> closer to 1.0 (mantissa=0x266666 < 0x400000) -> 2^0 = 1.0 (val=127)
+ // 2.5 -> closer to 2.0 (mantissa=0x200000 < 0x400000) -> 2^1 = 2.0 (val=128)
+ // Note: "nearest" differs from "up" for 1.3 and 2.5 (ceiling would give val=128/129).
+ const std::vector input = {1.5f, 3.0f, 1.3f, 2.5f};
+ const std::vector expected = {
+ Float8E8M0(128, Float8E8M0::FromBits()), // 1.5 -> 2.0 (tie, rounds up)
+ Float8E8M0(129, Float8E8M0::FromBits()), // 3.0 -> 4.0 (tie, rounds up)
+ Float8E8M0(127, Float8E8M0::FromBits()), // 1.3 -> 1.0 (nearer to 1.0)
+ Float8E8M0(128, Float8E8M0::FromBits()), // 2.5 -> 2.0 (nearer to 2.0)
+ };
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "nearest");
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_RoundModeNearestSubnormal) {
+ const std::vector shape{4};
+ const std::vector input = {
+ FloatFromBits(0x00400000), // 2^-127, exact E8M0 minimum
+ FloatFromBits(0x00500000), // below midpoint, differs from round_mode="up"
+ FloatFromBits(0x00600000), // exact midpoint, ties upward
+ FloatFromBits(0x007FFFFF), // largest float32 subnormal
+ };
+ const std::vector expected = {
+ Float8E8M0(0x00, Float8E8M0::FromBits()),
+ Float8E8M0(0x00, Float8E8M0::FromBits()),
+ Float8E8M0(0x01, Float8E8M0::FromBits()),
+ Float8E8M0(0x01, Float8E8M0::FromBits()),
+ };
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "nearest");
+}
+
+TEST(CastOpTest, Float8E8M0ToFloat) {
+ const std::vector shape{4};
+ const std::vector input = {
+ Float8E8M0(127, Float8E8M0::FromBits()), // 2^0 = 1.0
+ Float8E8M0(128, Float8E8M0::FromBits()), // 2^1 = 2.0
+ Float8E8M0(0, Float8E8M0::FromBits()), // 2^-127 (smallest)
+ Float8E8M0(254, Float8E8M0::FromBits()), // 2^127 (largest finite)
+ };
+
+ std::vector expected;
+ expected.reserve(input.size());
+ for (const auto& v : input) {
+ expected.emplace_back(v.ToFloat());
+ }
+
+ TestCastFromFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape);
+}
+
+TEST(CastOpTest, MLFloat16ToFloat8E8M0) {
+ const std::vector shape{4};
+ const std::vector float_values = {1.0f, 2.0f, 4.0f, 0.5f};  // exact powers of 2: expected result does not depend on round_mode
+ std::vector input;
+ input.reserve(float_values.size());
+ for (float v : float_values) {
+ input.emplace_back(MLFloat16(v));
+ }
+
+ std::vector expected;
+ expected.reserve(float_values.size());
+ for (float v : float_values) {
+ expected.emplace_back(Float8E8M0(v, true));  // reference value computed from the float, saturate=true
+ }
+
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True);
+}
+
+TEST(CastOpTest, DoubleToFloat8E8M0) {
+ const std::vector shape{4};
+ const std::vector input = {1.0, 2.0, 4.0, 0.5};  // exact powers of 2: expected result does not depend on round_mode
+
+ std::vector expected;
+ expected.reserve(input.size());
+ for (double v : input) {
+ expected.emplace_back(Float8E8M0(static_cast(v), true));  // reference value, saturate=true
+ }
+
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True);
+}
+
+TEST(CastOpTest, Int32ToFloat8E8M0) {
+ const std::vector shape{4};
+ const std::vector input = {1, 2, 4, 8};  // exact powers of 2: expected result does not depend on round_mode
+
+ std::vector expected;
+ expected.reserve(input.size());
+ for (int32_t v : input) {
+ expected.emplace_back(Float8E8M0(static_cast(v), true));  // reference value, saturate=true
+ }
+
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True);
+}
+
+TEST(CastOpTest, Float8E8M0ToDouble) {
+ const std::vector shape{3};
+ const std::vector input = {
+ Float8E8M0(127, Float8E8M0::FromBits()), // 1.0
+ Float8E8M0(128, Float8E8M0::FromBits()), // 2.0
+ Float8E8M0(126, Float8E8M0::FromBits()), // 0.5
+ };
+
+ std::vector expected;
+ expected.reserve(input.size());
+ for (const auto& v : input) {
+ expected.emplace_back(static_cast(v.ToFloat()));
+ }
+
+ TestCastFromFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape);
+}
+
+// Edge-case tests verifying E8M0 conversion table.
+// E8M0 format: val 0 = 2^(-127) (E8M0_MIN), val 254 = 2^127 (E8M0_MAX), val 255 = NaN.
+// E8M0 cannot represent zero, negative values, or infinity.
+
+TEST(CastOpTest, FloatToFloat8E8M0_SaturateUp_EdgeCases) {
+ // With saturate=true and round_mode="up", out-of-range values clamp to the nearest boundary; NaN stays NaN.
+ const std::vector shape{8};
+ const std::vector input = {
+ 0.0f, // x = 0 (below E8M0 range)
+ -0.0f, // x = -0 (behavior unspecified per spec)
+ NAN, // x = NaN
+ std::numeric_limits::infinity(), // x = +Inf
+ -std::numeric_limits::infinity(), // x = -Inf
+ 3e38f, // x > E8M0_MAX (~1.76 * 2^127)
+ 1e-39f, // x < E8M0_MIN (positive subnormal)
+ -1.0f, // x < 0
+ };
+ const std::vector expected = {
+ Float8E8M0(0, Float8E8M0::FromBits()), // 0 -> E8M0_MIN (val 0)
+ Float8E8M0(0, Float8E8M0::FromBits()), // -0 -> E8M0_MIN (val 0, treated same as +0)
+ Float8E8M0(255, Float8E8M0::FromBits()), // NaN -> NaN (val 255)
+ Float8E8M0(254, Float8E8M0::FromBits()), // +Inf -> E8M0_MAX (val 254)
+ Float8E8M0(0, Float8E8M0::FromBits()), // -Inf -> E8M0_MIN (val 0, negative saturated)
+ Float8E8M0(254, Float8E8M0::FromBits()), // 3e38 -> E8M0_MAX (val 254, overflow saturated)
+ Float8E8M0(0, Float8E8M0::FromBits()), // 1e-39 -> E8M0_MIN (val 0): ceiling of x < E8M0_MIN is E8M0_MIN
+ Float8E8M0(0, Float8E8M0::FromBits()), // -1 -> E8M0_MIN (val 0, negative saturated)
+ };
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "up");
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_NonSaturateNearest_EdgeCases) {
+ // With saturate=false and round_mode="nearest", all values outside [E8M0_MIN, E8M0_MAX] become NaN.
+ const std::vector shape{8};
+ const std::vector input = {
+ 0.0f, // x = 0 (not representable)
+ -0.0f, // x = -0 (not representable)
+ NAN, // x = NaN
+ std::numeric_limits::infinity(), // x = +Inf (not representable)
+ -std::numeric_limits::infinity(), // x = -Inf (not representable)
+ 3e38f, // x > E8M0_MAX (not representable)
+ 1e-39f, // x < E8M0_MIN: subnormal below 2^(-127) -> NaN
+ -1.0f, // x < 0 (not representable)
+ };
+ const std::vector expected(8, Float8E8M0(255, Float8E8M0::FromBits())); // all -> NaN
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::False, "nearest");
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_NonSaturate_AboveMax) {
+ // With saturate=false, any value strictly above E8M0_MAX (2^127) gives NaN,
+ // since 2^128 is not representable in E8M0 (val 255 = NaN).
+ const std::vector shape{2};
+ const std::vector input = {
+ 2e38f, // ~1.18 * 2^127, strictly above E8M0_MAX -> NaN
+ 3e38f, // ~1.76 * 2^127, strictly above E8M0_MAX -> NaN
+ };
+ const std::vector expected(2, Float8E8M0(255, Float8E8M0::FromBits())); // all -> NaN
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::False, "nearest");
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_Saturate_AboveMax) {
+ // With saturate=true, values above E8M0_MAX clamp to E8M0_MAX.
+ const std::vector shape{2};
+ const std::vector input = {2e38f, 3e38f};
+ const std::vector expected(2, Float8E8M0(254, Float8E8M0::FromBits())); // E8M0_MAX
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "up");
+}
+
+TEST(CastOpTest, FloatToFloat8E8M0_ExactMax) {
+ // Exactly E8M0_MAX (2^127) is representable in all modes; this case exercises saturate=false with "nearest".
+ const std::vector shape{1};
+ // 2^127 = 1.7014118e+38
+ const float e8m0_max = Float8E8M0(254, Float8E8M0::FromBits()).ToFloat();
+ const std::vector input = {e8m0_max};
+ const std::vector expected = {Float8E8M0(254, Float8E8M0::FromBits())};
+ TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::False, "nearest");
+}
+
+#endif // !defined(DISABLE_FLOAT8_TYPES)
+
} // namespace test
} // namespace onnxruntime