diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 7596ab7592b25..d90280a5a3357 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -69,8 +69,8 @@ The **OpSet Version** column uses the following notation: |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BlackmanWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| -|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||23|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| diff --git a/include/onnxruntime/core/common/float8.h b/include/onnxruntime/core/common/float8.h index 3afd1b0ebd767..9c342a346ee0b 100644 --- a/include/onnxruntime/core/common/float8.h +++ b/include/onnxruntime/core/common/float8.h @@ -701,7 +701,17 @@ struct Float8E8M0 { static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } constexpr ORT_HOST_DEVICE Float8E8M0(unsigned char bits, FromBitsT) : val(bits) {} - inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true) { + /// Rounding modes for Float8E8M0 conversion from float. + /// These correspond to the ONNX Cast op's round_mode attribute for float8e8m0. + /// See: https://github.com/onnx/onnx/blob/main/onnx/numpy_helper.py (to_float8e8m0) + enum class RoundMode : uint8_t { + Up, // Ceiling: always round up to next power of 2 when not exact (default). + Down, // Floor: always truncate to lower power of 2. + Nearest, // Round to nearest power of 2; ties round to higher power (round-half-up). + }; + + inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true, + RoundMode round_mode = RoundMode::Up) { uint32_t b; std::memcpy(&b, &v, sizeof(b)); @@ -756,34 +766,74 @@ struct Float8E8M0 { return; } - // Denormalized float32: value = 2^(-126) * (mantissa / 2^23) - // The largest subnormal is ~2^(-126) * (1 - 2^-23), which should round to 2^(-126) = val 1. - // The midpoint between 2^(-127) and 2^(-126) is 1.5 * 2^(-127). - // Subnormals with value >= midpoint round up to 2^(-126) (val=1), others to 2^(-127) (val=0). - // Midpoint in subnormal mantissa: 0x00600000 (mantissa >= 0.75 * 2^23 means value >= 1.5 * 2^-127). + // Denormalized float32: value = 2^(-126) * (mantissa / 2^23), range (0, 2^(-126)). + // E8M0 can represent 2^(-127) (val=0) and 2^(-126) (val=1). For nearest rounding, + // the midpoint is 1.5 * 2^(-127), which is mantissa 0x600000. Ties round up. if (exponent == 0) { - if (saturate) { - if (mantissa >= 0x00600000) { - val = 0x01; // Round up to 2^(-126) - } else { - val = 0x00; // Round down to 2^(-127) - } + // Subnormals with mantissa < 0x400000 have value < E8M0_MIN (2^-127) and + // cannot be represented. Without saturation they map to NaN. + // Subnormals with mantissa >= 0x400000 have value >= E8M0_MIN, so they + // round to val=0 or val=1, both valid E8M0 values. + if (!saturate && mantissa < 0x00400000) { + val = 0xFF; // NaN: x < E8M0_MIN is not representable without saturation + return; + } + bool round_up; + switch (round_mode) { + case RoundMode::Up: + // Ceiling: round up only when value > 2^(-127). Denorm mantissa == 0x400000 + // is exactly 2^(-127) (val=0), so it must NOT round up. + round_up = (mantissa > 0x00400000); + break; + case RoundMode::Down: + // Floor: always keep val=0 (2^(-127)), never increment. + round_up = false; + break; + case RoundMode::Nearest: + default: + round_up = (mantissa >= 0x00600000); + break; + } + if (round_up) { + val = 0x01; // 2^(-126) } else { - val = 0xFF; // NaN (subnormals are below E8M0 min for saturate=false) + val = 0x00; // 2^(-127) } return; } - // Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23) - // We need to round to the nearest power of 2. - // Round half up: round to next power of 2 when mantissa >= 0.5 - // (i.e., when the float value is >= 1.5 * nearest lower power of 2) - // This aligns with the OCP Microscaling Formats (MX) spec for E8M0 scaling factors. - if (mantissa >= 0x00400000) { // >= 0.5 + // Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23). + // Values with exponent=254 and mantissa>0 are in (2^127, 2^128). Since 2^128 + // is not representable in E8M0 (val 255 = NaN), without saturation these + // values cannot be rounded to any valid E8M0 value and must become NaN. + if (!saturate && exponent == 0xFE && mantissa != 0) { + val = 0xFF; // NaN: x > E8M0_MAX is not representable without saturation + return; + } + // Round to the nearest power of 2 using the ONNX semantics: + // Up (ceiling): round up when the float is not exactly a power of 2 (mantissa > 0). + // Down (floor): never round up; always keep the lower exponent. + // Nearest: G bit (bit 22) determines direction -- round up when mantissa >= 0x400000. + // For normal floats lsb of exponent is always considered 1, so ties + // round to the higher power of 2 (round-half-up). + bool round_up; + switch (round_mode) { + case RoundMode::Up: + round_up = (mantissa > 0); + break; + case RoundMode::Down: + round_up = false; + break; + case RoundMode::Nearest: + default: + round_up = (mantissa >= 0x00400000); + break; + } + if (round_up) { exponent += 1; } - // After rounding, exponent may overflow + // After rounding, exponent may overflow. if (exponent > 0xFE) { if (saturate) { val = 0xFE; // Largest finite: 2^127 diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 38d5c95813be2..1d5a7c63228b3 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -29,18 +29,27 @@ namespace onnxruntime { namespace { -// Define a type list that extends AllIRv10 with INT2 types, but without Float4 +// Define a type list that extends AllIRv10 with INT2 types, but without Float4. // Float4E2M1x2 doesn't support all the casting operations that other types do, // so we don't include it here for the Cast operator. -using AllIRv10WithInt2 = +using AllIRv10WithInt2Base = boost::mp11::mp_push_back< element_type_lists::AllIRv10, UInt2x4, Int2x4>; + +#if !defined(DISABLE_FLOAT8_TYPES) +// Float8E8M0 was added in opset 24 (IR v12), so include it in the full type list +// but use the base list (without Float8E8M0) for pre-opset-24 kernel registrations. +using AllIRv10WithInt2 = + boost::mp11::mp_push_back; +#else +using AllIRv10WithInt2 = AllIRv10WithInt2Base; +#endif } // namespace namespace op_kernel_type_control { -// we're using one set of types for all opsets of Cast +// Type list for all opsets of Cast (includes Float8E8M0 for runtime dispatch). ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0, AllIRv10WithInt2); @@ -64,6 +73,17 @@ using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecu using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0); +// Pre-opset-24 type lists (without Float8E8M0) for kernel registration TypeConstraints. +// Float8E8M0 was introduced in opset 24. +#if !defined(DISABLE_FLOAT8_TYPES) +using EnabledSrcTypesPreOpset24 = boost::mp11::mp_remove; +using EnabledDstTypesPreOpset24 = boost::mp11::mp_remove; +#else +// When float8 types are disabled, Float8E8M0 is not in the type lists, so no removal needed. +using EnabledSrcTypesPreOpset24 = EnabledSrcTypes; +using EnabledDstTypesPreOpset24 = EnabledDstTypes; +#endif + template using IsOrtFloat16Type = boost::mp11::mp_contains, T>; @@ -958,7 +978,8 @@ class Cast final : public OpKernel { if (saturate == 0 && (to != ONNX_NAMESPACE::TensorProto::FLOAT8E4M3FN && to != ONNX_NAMESPACE::TensorProto::FLOAT8E4M3FNUZ && to != ONNX_NAMESPACE::TensorProto::FLOAT8E5M2 && - to != ONNX_NAMESPACE::TensorProto::FLOAT8E5M2FNUZ)) { + to != ONNX_NAMESPACE::TensorProto::FLOAT8E5M2FNUZ && + to != ONNX_NAMESPACE::TensorProto::FLOAT8E8M0)) { ORT_THROW("Attribute saturate is only used for cast to float 8 types."); } #else @@ -967,6 +988,27 @@ class Cast final : public OpKernel { } #endif saturate_ = saturate == 1; + + // round_mode only applies for casting to float8e8m0 (introduced in opset 24) + std::string round_mode_str = info.GetAttrOrDefault("round_mode", std::string("up")); +#if !defined(DISABLE_FLOAT8_TYPES) + if (round_mode_str == "up") { + round_mode_ = Float8E8M0::RoundMode::Up; + } else if (round_mode_str == "down") { + round_mode_ = Float8E8M0::RoundMode::Down; + } else if (round_mode_str == "nearest") { + round_mode_ = Float8E8M0::RoundMode::Nearest; + } else { + ORT_THROW("Attribute round_mode must be 'up', 'down', or 'nearest'."); + } + if (round_mode_ != Float8E8M0::RoundMode::Up && to != ONNX_NAMESPACE::TensorProto::FLOAT8E8M0) { + ORT_THROW("Attribute round_mode is only used for cast to float8e8m0."); + } +#else + if (round_mode_str != "up") { + ORT_THROW("Attribute round_mode is only used for cast to float8e8m0."); + } +#endif } Status Compute(OpKernelContext* context) const override; @@ -974,6 +1016,9 @@ class Cast final : public OpKernel { private: ONNX_NAMESPACE::TensorProto_DataType to_; bool saturate_; +#if !defined(DISABLE_FLOAT8_TYPES) + Float8E8M0::RoundMode round_mode_{Float8E8M0::RoundMode::Up}; +#endif }; template @@ -1020,6 +1065,43 @@ struct SrcDispatcherNoSat { } }; +// Dispatcher for casting any source type to Float8E8M0 with round_mode and saturate support. +// This bypasses the generic TensorCaster/TensorCasterNoSat templates to thread round_mode through. +template +struct CastToE8M0Dispatcher { + void operator()(const OpKernelContext&, const TensorShape& shape, const Tensor& src, Tensor& dst, + bool saturate, Float8E8M0::RoundMode round_mode) { + const auto shape_size = narrow(shape.Size()); + auto* out_data = dst.MutableData(); + + if constexpr (IsOrtInt4Type::value) { + const auto* in_data = src.Data(); + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 1].GetElem(i & 0x1); + out_data[i] = Float8E8M0(static_cast(val), saturate, round_mode); + } + } else if constexpr (IsOrtInt2Type::value) { + const auto* in_data = src.Data(); + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 2].GetElem(i & 0x3); + out_data[i] = Float8E8M0(static_cast(val), saturate, round_mode); + } + } else if constexpr (std::is_same_v) { + const auto* in_data = src.Data(); + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + float float_val; + CastFromString(in_data[i], float_val); + out_data[i] = Float8E8M0(float_val, saturate, round_mode); + } + } else { + const auto* in_data = src.Data(); + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + out_data[i] = Float8E8M0(static_cast(in_data[i]), saturate, round_mode); + } + } + } +}; + #endif Status Cast::Compute(OpKernelContext* context) const { @@ -1040,6 +1122,14 @@ Status Cast::Compute(OpKernelContext* context) const { } #if !defined(DISABLE_FLOAT8_TYPES) + // Float8E8M0 destination needs special handling for round_mode support. + // Dispatch directly to avoid threading round_mode through the TensorCaster templates. + if (to_ == ONNX_NAMESPACE::TensorProto::FLOAT8E8M0) { + utils::MLTypeCallDispatcherFromTypeList dispatcher{from}; + dispatcher.Invoke(*context, shape, *X, *Y, saturate_, round_mode_); + return Status::OK(); + } + if (saturate_) { #endif utils::MLTypeCallDispatcherFromTypeList dispatcher{from}; @@ -1063,8 +1153,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 6, 12, KernelDefBuilder() - .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) - .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) .MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing Cast); @@ -1073,8 +1163,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 13, 18, KernelDefBuilder() - .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) - .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) .MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing Cast); @@ -1083,8 +1173,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 19, 20, KernelDefBuilder() - .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) - .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) .MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing Cast); @@ -1094,8 +1184,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 21, 22, KernelDefBuilder() - .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) - .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) .MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing Cast); @@ -1106,11 +1196,12 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 23, 23, KernelDefBuilder() - .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) - .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) .MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing Cast); +// Opset 24 added support for float8e8m0. ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Cast, 24, diff --git a/onnxruntime/test/framework/float8e8m0_test.cc b/onnxruntime/test/framework/float8e8m0_test.cc index e27694355c3e9..d02d6a16d8d40 100644 --- a/onnxruntime/test/framework/float8e8m0_test.cc +++ b/onnxruntime/test/framework/float8e8m0_test.cc @@ -121,9 +121,9 @@ TEST(Float8E8M0_Tests, Rounding) { Float8E8M0 val_1_5(1.5f); EXPECT_EQ(val_1_5.val, 128); // 2^1 = 2.0 - // 1.25 should round down to 1.0 (mantissa < 0.5) + // 1.25 rounds up to 2.0 with default "up" (ceiling) mode since mantissa != 0 Float8E8M0 val_1_25(1.25f); - EXPECT_EQ(val_1_25.val, 127); // 2^0 = 1.0 + EXPECT_EQ(val_1_25.val, 128); // 2^1 = 2.0 // 3.0 should round up to 4.0 (mantissa = 0.5) Float8E8M0 val_3(3.0f); @@ -257,9 +257,25 @@ TEST(Float8E8M0_Tests, SubnormalRounding) { Float8E8M0 from_small_subnorm(small_subnorm, true); EXPECT_EQ(from_small_subnorm.val, 0x00); // Rounds down to 2^(-127) - // With saturate=false, all subnormals produce NaN + // In nearest mode, subnormals below the midpoint between 2^-127 and 2^-126 + // round down to 2^-127. Up mode would round this value to 2^-126. + uint32_t below_midpoint_bits = 0x00500000; + float below_midpoint; + std::memcpy(&below_midpoint, &below_midpoint_bits, sizeof(float)); + Float8E8M0 nearest_below_midpoint(below_midpoint, true, Float8E8M0::RoundMode::Nearest); + EXPECT_EQ(nearest_below_midpoint.val, 0x00); + + // The exact midpoint ties upward. + uint32_t midpoint_bits = 0x00600000; + float midpoint; + std::memcpy(&midpoint, &midpoint_bits, sizeof(float)); + Float8E8M0 nearest_midpoint(midpoint, true, Float8E8M0::RoundMode::Nearest); + EXPECT_EQ(nearest_midpoint.val, 0x01); + + // With saturate=false, subnormals within E8M0 range are still valid positive values, + // so they round normally (not NaN). Largest subnormal rounds up to 2^(-126). Float8E8M0 subnorm_nosat(largest_subnorm, false); - EXPECT_TRUE(subnorm_nosat.IsNaN()); + EXPECT_EQ(subnorm_nosat.val, 0x01); } TEST(Float8E8M0_Tests, BatchConversionSpecialValues) { diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 4481cf36554cd..0e14bc59a09c9 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -3127,5 +3127,318 @@ TEST(CastOpTest, CopyCpuTensor_SubByteTypes_DistinctBuffers) { } } +#if !defined(DISABLE_FLOAT8_TYPES) + +float FloatFromBits(uint32_t bits) { + float value; + std::memcpy(&value, &bits, sizeof(float)); + return value; +} + +template +void TestCastToFloat8E8M0(gsl::span input, + gsl::span output, + const std::vector& shape, + Saturate saturate = Saturate::None, + const std::string& round_mode = "", + int opset = 24, + OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, + const std::string& expected_failure_string = "") { + OpTester test("Cast", opset); + test.AddAttribute("to", utils::ToTensorProtoElementType()); + test.AddInput("input", shape, input.data(), input.size()); + test.AddOutput("output", shape, output.data(), output.size()); + if (saturate != Saturate::None) { + test.AddAttribute("saturate", saturate == Saturate::True ? 1 : 0); + } + if (!round_mode.empty()) { + test.AddAttribute("round_mode", round_mode); + } + + std::vector> execution_providers; + execution_providers.emplace_back(DefaultCpuExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .Config(expect_result, expected_failure_string) + .RunWithConfig(); +} + +template +void TestCastFromFloat8E8M0(gsl::span input, + gsl::span output, + const std::vector& shape, + int opset = 24) { + OpTester test("Cast", opset); + test.AddAttribute("to", utils::ToTensorProtoElementType()); + test.AddInput("input", shape, input.data(), input.size()); + test.AddOutput("output", shape, output.data(), output.size()); + + std::vector> execution_providers; + execution_providers.emplace_back(DefaultCpuExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); +} + +TEST(CastOpTest, FloatToFloat8E8M0_Saturate) { + const std::vector shape{8}; + // Test values: NaN, -1, positive, 1.5 (tie), -Inf, +Inf, 0, very small + const std::vector input = {NAN, -1.0f, 4.0f, 1.5f, + -std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + 0.0f, 1e-39f}; + + std::vector expected; + expected.reserve(input.size()); + for (float v : input) { + expected.emplace_back(Float8E8M0(v, true)); + } + + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True); +} + +TEST(CastOpTest, FloatToFloat8E8M0_NoSaturate) { + const std::vector shape{6}; + const std::vector input = {NAN, -1.0f, 4.0f, 0.0f, + -std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; + + std::vector expected; + expected.reserve(input.size()); + for (float v : input) { + expected.emplace_back(Float8E8M0(v, false)); + } + + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::False); +} + +TEST(CastOpTest, FloatToFloat8E8M0_RoundModeUp) { + const std::vector shape{4}; + // "up" mode is ceiling: always round up to the next power of 2 when not exact. + // Exact powers of 2 (mantissa == 0) are unchanged; all others round up. + // 1.5 (mantissa != 0) -> 2^1 = 2.0 (val=128) + // 3.0 (mantissa != 0) -> 2^2 = 4.0 (val=129) + // 1.3 (mantissa != 0) -> 2^1 = 2.0 (val=128) [ceiling, not round-half-up] + // 2.5 (mantissa != 0) -> 2^2 = 4.0 (val=129) [ceiling, not round-half-up] + const std::vector input = {1.5f, 3.0f, 1.3f, 2.5f}; + const std::vector expected = { + Float8E8M0(128, Float8E8M0::FromBits()), // 1.5 -> 2.0 + Float8E8M0(129, Float8E8M0::FromBits()), // 3.0 -> 4.0 + Float8E8M0(128, Float8E8M0::FromBits()), // 1.3 -> 2.0 + Float8E8M0(129, Float8E8M0::FromBits()), // 2.5 -> 4.0 + }; + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "up"); +} + +TEST(CastOpTest, FloatToFloat8E8M0_RoundModeDown) { + const std::vector shape{4}; + // "down" mode is floor: always truncate to the lower power of 2, never increment. + // All non-power-of-2 values keep the lower exponent regardless of their fractional part. + // 1.5 -> 2^0 = 1.0 (val=127) [floor, not round-half-down] + // 3.0 -> 2^1 = 2.0 (val=128) [floor] + // 1.7 -> 2^0 = 1.0 (val=127) [floor, not round-half-down -- 1.7 > midpoint but still floors] + // 2.5 -> 2^1 = 2.0 (val=128) [floor] + const std::vector input = {1.5f, 3.0f, 1.7f, 2.5f}; + const std::vector expected = { + Float8E8M0(127, Float8E8M0::FromBits()), // 1.5 -> 1.0 + Float8E8M0(128, Float8E8M0::FromBits()), // 3.0 -> 2.0 + Float8E8M0(127, Float8E8M0::FromBits()), // 1.7 -> 1.0 + Float8E8M0(128, Float8E8M0::FromBits()), // 2.5 -> 2.0 + }; + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "down"); +} + +TEST(CastOpTest, FloatToFloat8E8M0_RoundModeNearest) { + const std::vector shape{4}; + // "nearest" mode: round to nearest power of 2; ties (exactly halfway) round up. + // Decision threshold: guard bit (bit 22 of mantissa, representing 0.5 of fractional part). + // 1.5 -> midpoint (mantissa=0x400000) -> round up -> 2^1 = 2.0 (val=128) + // 3.0 -> midpoint (mantissa=0x400000) -> round up -> 2^2 = 4.0 (val=129) + // 1.3 -> closer to 1.0 (mantissa=0x266666 < 0x400000) -> 2^0 = 1.0 (val=127) + // 2.5 -> closer to 2.0 (mantissa=0x200000 < 0x400000) -> 2^1 = 2.0 (val=128) + // Note: "nearest" differs from "up" for 1.3 and 2.5 (ceiling would give val=128/129). + const std::vector input = {1.5f, 3.0f, 1.3f, 2.5f}; + const std::vector expected = { + Float8E8M0(128, Float8E8M0::FromBits()), // 1.5 -> 2.0 (tie, rounds up) + Float8E8M0(129, Float8E8M0::FromBits()), // 3.0 -> 4.0 (tie, rounds up) + Float8E8M0(127, Float8E8M0::FromBits()), // 1.3 -> 1.0 (nearer to 1.0) + Float8E8M0(128, Float8E8M0::FromBits()), // 2.5 -> 2.0 (nearer to 2.0) + }; + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "nearest"); +} + +TEST(CastOpTest, FloatToFloat8E8M0_RoundModeNearestSubnormal) { + const std::vector shape{4}; + const std::vector input = { + FloatFromBits(0x00400000), // 2^-127, exact E8M0 minimum + FloatFromBits(0x00500000), // below midpoint, differs from round_mode="up" + FloatFromBits(0x00600000), // exact midpoint, ties upward + FloatFromBits(0x007FFFFF), // largest float32 subnormal + }; + const std::vector expected = { + Float8E8M0(0x00, Float8E8M0::FromBits()), + Float8E8M0(0x00, Float8E8M0::FromBits()), + Float8E8M0(0x01, Float8E8M0::FromBits()), + Float8E8M0(0x01, Float8E8M0::FromBits()), + }; + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "nearest"); +} + +TEST(CastOpTest, Float8E8M0ToFloat) { + const std::vector shape{4}; + const std::vector input = { + Float8E8M0(127, Float8E8M0::FromBits()), // 2^0 = 1.0 + Float8E8M0(128, Float8E8M0::FromBits()), // 2^1 = 2.0 + Float8E8M0(0, Float8E8M0::FromBits()), // 2^-127 (smallest) + Float8E8M0(254, Float8E8M0::FromBits()), // 2^127 (largest finite) + }; + + std::vector expected; + expected.reserve(input.size()); + for (const auto& v : input) { + expected.emplace_back(v.ToFloat()); + } + + TestCastFromFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape); +} + +TEST(CastOpTest, MLFloat16ToFloat8E8M0) { + const std::vector shape{4}; + const std::vector float_values = {1.0f, 2.0f, 4.0f, 0.5f}; + std::vector input; + input.reserve(float_values.size()); + for (float v : float_values) { + input.emplace_back(MLFloat16(v)); + } + + std::vector expected; + expected.reserve(float_values.size()); + for (float v : float_values) { + expected.emplace_back(Float8E8M0(v, true)); + } + + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True); +} + +TEST(CastOpTest, DoubleToFloat8E8M0) { + const std::vector shape{4}; + const std::vector input = {1.0, 2.0, 4.0, 0.5}; + + std::vector expected; + expected.reserve(input.size()); + for (double v : input) { + expected.emplace_back(Float8E8M0(static_cast(v), true)); + } + + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True); +} + +TEST(CastOpTest, Int32ToFloat8E8M0) { + const std::vector shape{4}; + const std::vector input = {1, 2, 4, 8}; + + std::vector expected; + expected.reserve(input.size()); + for (int32_t v : input) { + expected.emplace_back(Float8E8M0(static_cast(v), true)); + } + + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True); +} + +TEST(CastOpTest, Float8E8M0ToDouble) { + const std::vector shape{3}; + const std::vector input = { + Float8E8M0(127, Float8E8M0::FromBits()), // 1.0 + Float8E8M0(128, Float8E8M0::FromBits()), // 2.0 + Float8E8M0(126, Float8E8M0::FromBits()), // 0.5 + }; + + std::vector expected; + expected.reserve(input.size()); + for (const auto& v : input) { + expected.emplace_back(static_cast(v.ToFloat())); + } + + TestCastFromFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape); +} + +// Edge-case tests verifying E8M0 conversion table. +// E8M0 format: val 0 = 2^(-127) (E8M0_MIN), val 254 = 2^127 (E8M0_MAX), val 255 = NaN. +// E8M0 cannot represent zero, negative values, or infinity. + +TEST(CastOpTest, FloatToFloat8E8M0_SaturateUp_EdgeCases) { + // With saturate=true and round_mode="up", out-of-range values clamp to the nearest boundary. + const std::vector shape{8}; + const std::vector input = { + 0.0f, // x = 0 (below E8M0 range) + -0.0f, // x = -0 (behavior unspecified per spec) + NAN, // x = NaN + std::numeric_limits::infinity(), // x = +Inf + -std::numeric_limits::infinity(), // x = -Inf + 3e38f, // x > E8M0_MAX (~1.76 * 2^127) + 1e-39f, // x < E8M0_MIN (positive subnormal) + -1.0f, // x < 0 + }; + const std::vector expected = { + Float8E8M0(0, Float8E8M0::FromBits()), // 0 → E8M0_MIN (val 0) + Float8E8M0(0, Float8E8M0::FromBits()), // -0 → E8M0_MIN (val 0, treated same as +0) + Float8E8M0(255, Float8E8M0::FromBits()), // NaN → NaN (val 255) + Float8E8M0(254, Float8E8M0::FromBits()), // +Inf → E8M0_MAX (val 254) + Float8E8M0(0, Float8E8M0::FromBits()), // -Inf → E8M0_MIN (val 0, negative saturated) + Float8E8M0(254, Float8E8M0::FromBits()), // 3e38 → E8M0_MAX (val 254, overflow saturated) + Float8E8M0(0, Float8E8M0::FromBits()), // 1e-39 -> E8M0_MIN (val 0): ceiling of x < E8M0_MIN is E8M0_MIN + Float8E8M0(0, Float8E8M0::FromBits()), // -1 → E8M0_MIN (val 0, negative saturated) + }; + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "up"); +} + +TEST(CastOpTest, FloatToFloat8E8M0_NonSaturateNearest_EdgeCases) { + // With saturate=false and round_mode="nearest", all values outside [E8M0_MIN, E8M0_MAX] become NaN. + const std::vector shape{8}; + const std::vector input = { + 0.0f, // x = 0 (not representable) + -0.0f, // x = -0 (not representable) + NAN, // x = NaN + std::numeric_limits::infinity(), // x = +Inf (not representable) + -std::numeric_limits::infinity(), // x = -Inf (not representable) + 3e38f, // x > E8M0_MAX (not representable) + 1e-39f, // x < E8M0_MIN: subnormal below 2^(-127) -> NaN + -1.0f, // x < 0 (not representable) + }; + const std::vector expected(8, Float8E8M0(255, Float8E8M0::FromBits())); // all -> NaN + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::False, "nearest"); +} + +TEST(CastOpTest, FloatToFloat8E8M0_NonSaturate_AboveMax) { + // With saturate=false, any value strictly above E8M0_MAX (2^127) gives NaN, + // since 2^128 is not representable in E8M0 (val 255 = NaN). + const std::vector shape{2}; + const std::vector input = { + 2e38f, // ~1.18 * 2^127, strictly above E8M0_MAX -> NaN + 3e38f, // ~1.76 * 2^127, strictly above E8M0_MAX -> NaN + }; + const std::vector expected(2, Float8E8M0(255, Float8E8M0::FromBits())); // all -> NaN + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::False, "nearest"); +} + +TEST(CastOpTest, FloatToFloat8E8M0_Saturate_AboveMax) { + // With saturate=true, values above E8M0_MAX clamp to E8M0_MAX. + const std::vector shape{2}; + const std::vector input = {2e38f, 3e38f}; + const std::vector expected(2, Float8E8M0(254, Float8E8M0::FromBits())); // E8M0_MAX + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::True, "up"); +} + +TEST(CastOpTest, FloatToFloat8E8M0_ExactMax) { + // Exactly E8M0_MAX (2^127) is representable in all modes. + const std::vector shape{1}; + // 2^127 = 1.7014118e+38 + const float e8m0_max = Float8E8M0(254, Float8E8M0::FromBits()).ToFloat(); + const std::vector input = {e8m0_max}; + const std::vector expected = {Float8E8M0(254, Float8E8M0::FromBits())}; + TestCastToFloat8E8M0(gsl::make_span(input), gsl::make_span(expected), shape, Saturate::False, "nearest"); +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) + } // namespace test } // namespace onnxruntime