Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ The **OpSet Version** column uses the following notation:
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BlackmanWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Cast|*in* input:**T1**<br> *out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|Cast|*in* input:**T1**<br> *out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||23|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
Expand Down
94 changes: 74 additions & 20 deletions include/onnxruntime/core/common/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,17 @@ struct Float8E8M0 {
static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
constexpr ORT_HOST_DEVICE Float8E8M0(unsigned char bits, FromBitsT) : val(bits) {}

inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true) {
/// Selects how a float is rounded to a power of 2 when converted to Float8E8M0.
/// The modes mirror the ONNX Cast op's round_mode attribute for float8e8m0.
/// See: https://github.com/onnx/onnx/blob/main/onnx/numpy_helper.py (to_float8e8m0)
enum class RoundMode : uint8_t {
  /// Ceiling (the constructor's default): an inexact input moves up to the next power of 2.
  Up = 0,
  /// Floor: an inexact input is truncated to the lower power of 2.
  Down = 1,
  /// Nearest power of 2; a tie goes to the higher power (round-half-up).
  Nearest = 2,
};

inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true,
RoundMode round_mode = RoundMode::Up) {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));

Expand Down Expand Up @@ -756,34 +766,78 @@ struct Float8E8M0 {
return;
}

// Denormalized float32: value = 2^(-126) * (mantissa / 2^23)
// The largest subnormal is ~2^(-126) * (1 - 2^-23), which should round to 2^(-126) = val 1.
// The midpoint between 2^(-127) and 2^(-126) is 1.5 * 2^(-127).
// Subnormals with value >= midpoint round up to 2^(-126) (val=1), others to 2^(-127) (val=0).
// Midpoint in subnormal mantissa: 0x00600000 (mantissa >= 0.75 * 2^23 means value >= 1.5 * 2^-127).
// Denormalized float32: value = 2^(-126) * (mantissa / 2^23), range (0, 2^(-126)).
// E8M0 can represent 2^(-127) (val=0) and 2^(-126) (val=1). Round using the same
// G/R/S scheme as the ONNX reference (to_float8e8m0 in onnx/numpy_helper.py):
// G (guard) = bit 22 of mantissa; R (round) = bit 21; S (sticky) = bits 20:0.
// For subnormals, lsb of result exponent = 0, so "nearest" rounds up only when
// G=1 AND (R=1 OR S!=0), i.e. mantissa > 0x400000.
if (exponent == 0) {
if (saturate) {
if (mantissa >= 0x00600000) {
val = 0x01; // Round up to 2^(-126)
} else {
val = 0x00; // Round down to 2^(-127)
}
// Subnormals with mantissa < 0x400000 have value < E8M0_MIN (2^-127) and
// cannot be represented. Without saturation they map to NaN.
// Subnormals with mantissa >= 0x400000 have value >= E8M0_MIN, so they
// round to val=0 or val=1, both valid E8M0 values.
if (!saturate && mantissa < 0x00400000) {
val = 0xFF; // NaN: x < E8M0_MIN is not representable without saturation
return;
}
bool round_up;
switch (round_mode) {
case RoundMode::Up:
// Ceiling: round up only when value > 2^(-127). Denorm mantissa == 0x400000
// is exactly 2^(-127) (val=0), so it must NOT round up.
round_up = (mantissa > 0x00400000);
break;
case RoundMode::Down:
// Floor: always keep val=0 (2^(-127)), never increment.
round_up = false;
break;
case RoundMode::Nearest:
default:
// Round to nearest (G/R/S with lsb=0): round up when G=1 and (R=1 or S!=0), i.e. mantissa > 0x400000, meaning the value is strictly above 2^(-127).
round_up = (mantissa > 0x00400000);
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
break;
}
if (round_up) {
val = 0x01; // 2^(-126)
} else {
val = 0xFF; // NaN (subnormals are below E8M0 min for saturate=false)
val = 0x00; // 2^(-127)
}
return;
}

// Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23)
// We need to round to the nearest power of 2.
// Round half up: round to next power of 2 when mantissa >= 0.5
// (i.e., when the float value is >= 1.5 * nearest lower power of 2)
// This aligns with the OCP Microscaling Formats (MX) spec for E8M0 scaling factors.
if (mantissa >= 0x00400000) { // >= 0.5
// Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23).
// Values with exponent=254 and mantissa>0 are in (2^127, 2^128). Since 2^128
// is not representable in E8M0 (val 255 = NaN), without saturation these
// values cannot be rounded to any valid E8M0 value and must become NaN.
if (!saturate && exponent == 0xFE && mantissa != 0) {
val = 0xFF; // NaN: x > E8M0_MAX is not representable without saturation
return;
}
// Round to the nearest power of 2 using the ONNX semantics:
// Up (ceiling): round up when the float is not exactly a power of 2 (mantissa > 0).
// Down (floor): never round up; always keep the lower exponent.
// Nearest: G bit (bit 22) determines direction -- round up when mantissa >= 0x400000.
// For normal floats lsb of exponent is always considered 1, so ties
// round to the higher power of 2 (round-half-up).
bool round_up;
switch (round_mode) {
case RoundMode::Up:
round_up = (mantissa > 0);
break;
case RoundMode::Down:
round_up = false;
break;
case RoundMode::Nearest:
default:
round_up = (mantissa >= 0x00400000);
break;
}
if (round_up) {
exponent += 1;
}

// After rounding, exponent may overflow
// After rounding, exponent may overflow.
if (exponent > 0xFE) {
if (saturate) {
val = 0xFE; // Largest finite: 2^127
Expand Down
Loading
Loading