Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 77 additions & 73 deletions docs/OperatorKernels.md

Large diffs are not rendered by default.

206 changes: 206 additions & 0 deletions include/onnxruntime/core/common/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,150 @@
}
}

// Float8E8M0
// 8-bit floating point with 8 exponent bits and 0 mantissa bits (no sign bit).
// All representable values are powers of two: 2^(val - 127).
// Special value: 0xFF = NaN.
// The format has no encoding for zero, infinity, or negative values; the
// float32 constructor below documents how those inputs are mapped.
struct Float8E8M0 {
  uint8_t val{0};  // Raw 8-bit exponent value. Represents 2^(val - 127). 0xFF = NaN.
#if defined(__HIP__)
  ORT_HOST_DEVICE Float8E8M0() = default;
#else
  Float8E8M0() = default;
#endif

  // Tag type that selects the raw-bits constructor, so that an integer argument
  // is never confused with a numeric value to convert.
  struct FromBitsT {};
  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }

  // Wraps an existing 8-bit encoding without any conversion.
  constexpr ORT_HOST_DEVICE Float8E8M0(unsigned char bits, FromBitsT) : val(bits) {}

  // Converts a float32 to the nearest representable power of two.
  //
  //   NaN (either sign)      -> 0xFF (NaN), regardless of `saturate`
  //   negative, non-zero     -> 0x00 (2^-127) if `saturate`, otherwise 0xFF (NaN)
  //   +infinity / overflow   -> 0xFE (2^127)  if `saturate`, otherwise 0xFF (NaN)
  //   +0, -0, subnormals     -> 0x00 (2^-127)
  //   normal positive values -> exponent, rounded up when the fraction is >= 0.5
  inline explicit ORT_HOST_DEVICE Float8E8M0(float v, bool saturate = true) {
    uint32_t b;
    std::memcpy(&b, &v, sizeof(b));

    const uint32_t sign = b & 0x80000000;
    uint32_t exponent = (b & 0x7F800000) >> 23;
    const uint32_t mantissa = b & 0x007FFFFF;

    // NaN. This has to be tested before the sign: a NaN may carry a set sign bit
    // (for example after negation), and it must not be mistaken for an ordinary
    // negative number and saturated to 2^-127.
    if (exponent == 0xFF && mantissa != 0) {
      val = 0xFF;
      return;
    }

    // Negative values (except -0) cannot be represented. This also covers -infinity.
    if (sign != 0 && (exponent != 0 || mantissa != 0)) {
      if (saturate) {
        // Saturate negative to smallest positive (2^-127)
        val = 0x00;
      } else {
        val = 0xFF;  // NaN
      }
      return;
    }

    // Infinity (only +infinity can reach this point)
    if (exponent == 0xFF) {
      if (saturate) {
        val = 0xFE;  // Largest finite value: 2^127
      } else {
        val = 0xFF;  // NaN (no infinity in this format)
      }
      return;
    }

    // Zero or denormalized float32 maps to smallest value
    if (exponent == 0) {
      val = 0x00;  // 2^(-127)
      return;
    }

    // Normal float32: value is 2^(exponent - 127) * (1 + mantissa/2^23)
    // We need to round to the nearest power of 2.
    // If the fraction is >= 0.5 (i.e., mantissa >= 2^22), round up the exponent.
    if (mantissa >= 0x00400000) {
      exponent += 1;
    }

    // After rounding, exponent may overflow into the NaN encoding
    if (exponent > 0xFE) {
      if (saturate) {
        val = 0xFE;  // Largest finite: 2^127
      } else {
        val = 0xFF;  // NaN
      }
      return;
    }

    val = static_cast<uint8_t>(exponent);
  }

  // Returns true when this value is the single NaN encoding (0xFF).
  inline ORT_HOST_DEVICE bool IsNaN() const {
    return val == 0xFF;
  }

  // Converts to float32. Every non-NaN encoding is exactly representable.
  inline ORT_HOST_DEVICE float ToFloat() const {
    uint32_t res;
    if (val == 0xFF) {
      res = 0x7FC00000;  // quiet NaN
    } else if (val == 0) {
      // 2^(-127) is a denormalized float32: sign=0, exponent=0, mantissa=2^22
      // Denorm value = 2^(-126) * (mantissa/2^23) = 2^(-126) * 0.5 = 2^(-127)
      res = 0x00400000;
    } else {
      // For val 1-254: value is 2^(val - 127)
      // In float32: exponent field = val, mantissa = 0, sign = 0
      res = static_cast<uint32_t>(val) << 23;
    }

    float float_res;
    std::memcpy(&float_res, &res, sizeof(float));
    return float_res;
  }

  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
};

// Comparisons operate on the raw encodings. Because the encoding is a biased
// exponent, raw order matches numeric order for non-NaN values. Note that the
// NaN encoding (0xFF) therefore compares equal to itself and greater than every
// other value under these operators.
inline ORT_HOST_DEVICE bool operator==(const Float8E8M0& left, const Float8E8M0& right) {
  return left.val == right.val;
}

inline ORT_HOST_DEVICE bool operator!=(const Float8E8M0& left, const Float8E8M0& right) {
  return !(left.val == right.val);
}

inline ORT_HOST_DEVICE bool operator<(const Float8E8M0& left, const Float8E8M0& right) {
  return right.val > left.val;
}

// User defined suffixes to make it easier to declare
// initializers with Float8E8M0 from unsigned char
#if !defined(__CUDACC__) && !defined(__HIPCC__)

// Builds a Float8E8M0 directly from its raw 8-bit encoding (no numeric
// conversion), e.g. 0x7F_f8e8m0 is the encoding of 2^0.
// The literal is passed through narrow<uint8_t>; how out-of-range values are
// reported is determined by that helper (defined outside this block).
// `unsigned long long int` is the parameter type the language requires for
// integer literal operators, hence the lint suppression.
inline Float8E8M0 operator""_f8e8m0(unsigned long long int v) {  // NOLINT(runtime/int)
  return Float8E8M0(narrow<uint8_t>(v), Float8E8M0::FromBits());
}

// Builds a Float8E8M0 from a floating-point literal by converting it to
// float32 and then to Float8E8M0 with saturation enabled.
inline Float8E8M0 operator""_f8e8m0p8(long double v) {
  const float as_float = static_cast<float>(v);
  return Float8E8M0(as_float, /*saturate=*/true);
}

#endif

inline void Float8E8M0ToFloat(const Float8E8M0* blf, float* flt, size_t size) {
auto src = blf;
auto d = flt;
for (; size != 0; ++src, ++d, --size) {
*d = src->ToFloat();
}
}

// Converts `size` consecutive float32 values starting at `flt` to Float8E8M0,
// constructing each result in place at `blf`. `saturate` is forwarded to the
// Float8E8M0 float constructor. Does nothing when `size` is 0.
inline void FloatToFloat8E8M0(const float* flt, Float8E8M0* blf, size_t size, bool saturate) {
  for (size_t i = 0; i < size; ++i) {
    // Placement-new: the destination is constructed, not assigned to.
    new (blf + i) Float8E8M0(flt[i], saturate);
  }
}

} // namespace onnxruntime

namespace std {
Expand Down Expand Up @@ -932,6 +1076,68 @@
static constexpr auto tinyness_before = false;
};

// std::numeric_limits specialization for onnxruntime::Float8E8M0.
// The type has no sign bit, no zero, no infinity and no subnormals; its values
// are the powers of two 2^-127 (0x00) through 2^127 (0xFE), plus NaN (0xFF).
template <>
class numeric_limits<onnxruntime::Float8E8M0> {
 public:
  // Float8E8M0 has no sign bit, so lowest == min (smallest positive normal)
  static constexpr onnxruntime::Float8E8M0 lowest() {
    return onnxruntime::Float8E8M0(0x00, onnxruntime::Float8E8M0::FromBits());  // 2^-127
  }

  static constexpr onnxruntime::Float8E8M0 max() {
    return onnxruntime::Float8E8M0(0xFE, onnxruntime::Float8E8M0::FromBits());  // 2^127
  }

  static constexpr onnxruntime::Float8E8M0 min() {
    return onnxruntime::Float8E8M0(0x00, onnxruntime::Float8E8M0::FromBits());  // 2^-127
  }

  // No denormals: returns the same value as min().
  static constexpr onnxruntime::Float8E8M0 denorm_min() {
    return onnxruntime::Float8E8M0(0x00, onnxruntime::Float8E8M0::FromBits());
  }

  // 2^0 = 1.0: the next representable value after 1.0 is 2.0, so eps = 1.0.
  static constexpr onnxruntime::Float8E8M0 epsilon() {
    return onnxruntime::Float8E8M0(0x7F, onnxruntime::Float8E8M0::FromBits());
  }

  static constexpr onnxruntime::Float8E8M0 round_error() {
    return onnxruntime::Float8E8M0(0x7F, onnxruntime::Float8E8M0::FromBits());  // 1.0
  }

  static constexpr onnxruntime::Float8E8M0 infinity() {
    // no infinity, returns quiet NaN instead
    return quiet_NaN();
  }

  static constexpr onnxruntime::Float8E8M0 quiet_NaN() {
    return onnxruntime::Float8E8M0(0xFF, onnxruntime::Float8E8M0::FromBits());
  }

  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = false;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = false;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = false;
  static constexpr auto has_denorm = false;
  static constexpr auto has_denorm_loss = false;
  static constexpr auto round_style = round_to_nearest;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 1;  // only the implicit leading bit
  static constexpr int digits10 = 0;
  static constexpr int max_digits10 = 1;
  static constexpr int radix = 2;
  static constexpr int min_exponent = -126;  // one more than the exponent of min() = 2^-127
  static constexpr int min_exponent10 = -38;
  static constexpr int max_exponent = 128;  // one more than the exponent of max() = 2^127
  static constexpr int max_exponent10 = 38;
  static constexpr auto traps = false;
  static constexpr auto tinyness_before = false;
};

} // namespace std

#endif // DISABLE_FLOAT8_TYPES
4 changes: 2 additions & 2 deletions include/onnxruntime/core/framework/data_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_
Int4x2, UInt4x2, Int2x4, UInt2x4
#if !defined(DISABLE_FLOAT8_TYPES)
,
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ, Float8E8M0
#endif
#if !defined(DISABLE_FLOAT4_TYPES)
,
Expand All @@ -310,7 +310,7 @@ struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, u
Int4x2, UInt4x2, Int2x4, UInt2x4
#if !defined(DISABLE_FLOAT8_TYPES)
,
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ, Float8E8M0
#endif
#if !defined(DISABLE_FLOAT4_TYPES)
,
Expand Down
12 changes: 12 additions & 0 deletions include/onnxruntime/core/framework/data_types_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
function<Float8E5M2FNUZ>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E8M0: \
function<Float8E8M0>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \
function<Float4E2M1x2>(__VA_ARGS__); \
break; \
Expand Down Expand Up @@ -168,6 +171,9 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
retval = function<Float8E5M2FNUZ>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E8M0: \
retval = function<Float8E8M0>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \
retval = function<Float4E2M1x2>(__VA_ARGS__); \
break; \
Expand Down Expand Up @@ -373,6 +379,9 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
function<Float8E5M2FNUZ>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E8M0: \
function<Float8E8M0>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
function<Int4x2>(__VA_ARGS__); \
break; \
Expand Down Expand Up @@ -445,6 +454,9 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
retval = function<Float8E5M2FNUZ>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E8M0: \
retval = function<Float8E8M0>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
retval = function<Int4x2>(__VA_ARGS__); \
break; \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ template <>
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Float8E5M2FNUZ>() {
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
}
// Maps the Float8E8M0 C++ type to its ONNX TensorProto element type (FLOAT8E8M0).
template <>
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Float8E8M0>() {
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E8M0;
}

#endif

Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ typedef enum ONNXTensorElementDataType {
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, // maps to a pair of packed int4 values (size == 1 byte)
// Float4 types were introduced in ONNX 1.18. See https://onnx.ai/onnx/technical/float4.html
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1, // maps to a pair of packed float4 values (size == 1 byte)
// Float8E8M0 type: 8-bit float with 8 exponent bits, 0 mantissa bits, no sign bit
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E8M0, // Non-IEEE floating-point format, all values are powers of two
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
// Int2 types were introduced in ONNX 1.20. See https://onnx.ai/onnx/technical/int2.html
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2, // maps to 4 packed uint2 values (size == 1 byte)
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2, // maps to 4 packed int2 values (size == 1 byte)
Expand Down
Loading
Loading