Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ The **OpSet Version** column uses the following notation:
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BlackmanWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Cast|*in* input:**T1**<br> *out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|Cast|*in* input:**T1**<br> *out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||23|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
Expand Down
90 changes: 70 additions & 20 deletions include/onnxruntime/core/common/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,17 @@ struct Float8E8M0 {
static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
constexpr ORT_HOST_DEVICE Float8E8M0(unsigned char bits, FromBitsT) : val(bits) {}

inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true) {
/// Rounding modes for Float8E8M0 conversion from float.
/// These correspond to the ONNX Cast op's round_mode attribute for float8e8m0.
/// See: https://github.com/onnx/onnx/blob/main/onnx/numpy_helper.py (to_float8e8m0)
enum class RoundMode : uint8_t {
Up, // Ceiling: always round up to next power of 2 when not exact (default).
Down, // Floor: always truncate to lower power of 2.
Nearest, // Round to nearest power of 2; ties round to higher power (round-half-up).
};

inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true,
RoundMode round_mode = RoundMode::Up) {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));

Expand Down Expand Up @@ -756,34 +766,74 @@ struct Float8E8M0 {
return;
}

// Denormalized float32: value = 2^(-126) * (mantissa / 2^23)
// The largest subnormal is ~2^(-126) * (1 - 2^-23), which should round to 2^(-126) = val 1.
// The midpoint between 2^(-127) and 2^(-126) is 1.5 * 2^(-127).
// Subnormals with value >= midpoint round up to 2^(-126) (val=1), others to 2^(-127) (val=0).
// Midpoint in subnormal mantissa: 0x00600000 (mantissa >= 0.75 * 2^23 means value >= 1.5 * 2^-127).
// Denormalized float32: value = 2^(-126) * (mantissa / 2^23), range (0, 2^(-126)).
// E8M0 can represent 2^(-127) (val=0) and 2^(-126) (val=1). For nearest rounding,
// the midpoint is 1.5 * 2^(-127), which is mantissa 0x600000. Ties round up.
if (exponent == 0) {
if (saturate) {
if (mantissa >= 0x00600000) {
val = 0x01; // Round up to 2^(-126)
} else {
val = 0x00; // Round down to 2^(-127)
}
// Subnormals with mantissa < 0x400000 have value < E8M0_MIN (2^-127) and
// cannot be represented. Without saturation they map to NaN.
// Subnormals with mantissa >= 0x400000 have value >= E8M0_MIN, so they
// round to val=0 or val=1, both valid E8M0 values.
if (!saturate && mantissa < 0x00400000) {
val = 0xFF; // NaN: x < E8M0_MIN is not representable without saturation
return;
}
bool round_up;
switch (round_mode) {
case RoundMode::Up:
// Ceiling: round up only when value > 2^(-127). Denorm mantissa == 0x400000
// is exactly 2^(-127) (val=0), so it must NOT round up.
round_up = (mantissa > 0x00400000);
break;
case RoundMode::Down:
// Floor: always keep val=0 (2^(-127)), never increment.
round_up = false;
break;
case RoundMode::Nearest:
default:
round_up = (mantissa >= 0x00600000);
break;
}
if (round_up) {
val = 0x01; // 2^(-126)
} else {
val = 0xFF; // NaN (subnormals are below E8M0 min for saturate=false)
val = 0x00; // 2^(-127)
}
return;
}

// Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23)
// We need to round to the nearest power of 2.
// Round half up: round to next power of 2 when mantissa >= 0.5
// (i.e., when the float value is >= 1.5 * nearest lower power of 2)
// This aligns with the OCP Microscaling Formats (MX) spec for E8M0 scaling factors.
if (mantissa >= 0x00400000) { // >= 0.5
// Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23).
// Values with exponent=254 and mantissa>0 are in (2^127, 2^128). Since 2^128
// is not representable in E8M0 (val 255 = NaN), without saturation these
// values cannot be rounded to any valid E8M0 value and must become NaN.
if (!saturate && exponent == 0xFE && mantissa != 0) {
val = 0xFF; // NaN: x > E8M0_MAX is not representable without saturation
return;
}
// Round to the nearest power of 2 using the ONNX semantics:
// Up (ceiling): round up when the float is not exactly a power of 2 (mantissa > 0).
// Down (floor): never round up; always keep the lower exponent.
// Nearest: G bit (bit 22) determines direction -- round up when mantissa >= 0x400000.
// For normal floats lsb of exponent is always considered 1, so ties
// round to the higher power of 2 (round-half-up).
bool round_up;
switch (round_mode) {
case RoundMode::Up:
round_up = (mantissa > 0);
break;
case RoundMode::Down:
round_up = false;
break;
case RoundMode::Nearest:
default:
round_up = (mantissa >= 0x00400000);
break;
}
if (round_up) {
exponent += 1;
}

// After rounding, exponent may overflow
// After rounding, exponent may overflow.
if (exponent > 0xFE) {
if (saturate) {
val = 0xFE; // Largest finite: 2^127
Expand Down
Loading
Loading