44 changes: 2 additions & 42 deletions include/onnxruntime/core/framework/data_types.h
@@ -59,50 +59,10 @@ struct MLFloat16 {
   explicit MLFloat16(uint16_t x) : val(x) {}
   explicit MLFloat16(float f);
 
-  // Taken from https://stackoverflow.com/a/60047308/12627730
-  float AsFloat(uint32_t x) const {
-    float out = 0.0f;
-    std::memcpy(&out, &x, sizeof(x));
-    return out;
-  }
-
-  // Taken from https://stackoverflow.com/a/60047308/12627730
-  uint32_t AsUint(float x) const {
-    uint32_t out = 0;
-    std::memcpy(&out, &x, sizeof(x));
-    return out;
-  }
-
-  float HalfToFloat(const uint16_t x) const {
-    uint16_t half = x;
-    if (endian::native == endian::big) {
-      // Taken from https://stackoverflow.com/a/2182184/12627730
-      half = (x >> 8) | (x << 8);
-    }
-
-    // Taken from https://stackoverflow.com/a/60047308/12627730
-    // IEEE-754 16-bit floating-point format (without infinity): 1-5-10, exp-15, +-131008.0, +-6.1035156E-5,
-    // +-5.9604645E-8, 3.311 digits
-    const uint32_t e = (half & 0x7C00) >> 10;  // exponent
-    const uint32_t m = (half & 0x03FF) << 13;  // mantissa
-    // evil log2 bit hack to count leading zeros in denormalized format
-    const uint32_t v = AsUint(static_cast<float>(m)) >> 23;
-    uint32_t full = (half & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) |
-                    ((e == 0) & (m != 0)) * ((v - 37) << 23 | ((m << (150 - v)) & 0x007FE000));  // sign : normalized : denormalized
-
-    if (endian::native == endian::big) {
-      // Taken from https://stackoverflow.com/a/2182184/12627730
-      full = ((full >> 24) & 0xff) |       // move byte 3 to byte 0
-             ((full << 8) & 0xff0000) |    // move byte 1 to byte 2
-             ((full >> 8) & 0xff00) |      // move byte 2 to byte 1
-             ((full << 24) & 0xff000000);  // byte 0 to byte 3
-    }
-
-    return AsFloat(full);
-  }
+  float ToFloat() const;
 
   operator float() const {
-    return HalfToFloat(val);
+    return ToFloat();
   }
 };
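For reference, the deleted HalfToFloat decodes the three IEEE-754 binary16 fields (1 sign bit, 5 exponent bits with bias 15, 10 mantissa bits) using branchless bit tricks. Below is a minimal standalone sketch of the same decoding written the straightforward way, useful for cross-checking a conversion like this; the function name HalfBitsToFloat and the test values are illustrative, not part of the PR.

// Standalone sketch (not from the PR): a readable binary16 -> float decoder
// covering the zero/subnormal, normal, and infinity/NaN cases.
#include <cmath>
#include <cstdint>
#include <cstdio>

float HalfBitsToFloat(uint16_t h) {
  const uint32_t sign = (h >> 15) & 0x1;
  const uint32_t exp = (h >> 10) & 0x1F;  // 5-bit exponent, bias 15
  const uint32_t man = h & 0x3FF;         // 10-bit mantissa

  float value;
  if (exp == 0) {
    // Zero or subnormal: value = mantissa * 2^-24 (i.e. mantissa/1024 * 2^-14)
    value = std::ldexp(static_cast<float>(man), -24);
  } else if (exp == 0x1F) {
    // All-ones exponent encodes infinity (mantissa 0) or NaN
    value = (man == 0) ? INFINITY : NAN;
  } else {
    // Normal: value = (1 + mantissa/1024) * 2^(exp - 15)
    value = std::ldexp(1.0f + man / 1024.0f, static_cast<int>(exp) - 15);
  }
  return sign ? -value : value;
}

int main() {
  std::printf("%f\n", HalfBitsToFloat(0x3C00));  // 1.0
  std::printf("%f\n", HalfBitsToFloat(0xC000));  // -2.0
  std::printf("%g\n", HalfBitsToFloat(0x0001));  // smallest subnormal, ~5.96e-8
}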

2 changes: 1 addition & 1 deletion include/onnxruntime/core/platform/threadpool.h
@@ -281,7 +281,7 @@ class ThreadPool {
   /**
    * Tries to call the given function in parallel, with calls split into (num_batches) batches.
    *\param num_batches If it is zero, it will be replaced to the value of DegreeOfParallelism().
-   *\param fn A std::function or STL style functor with signature of "void f(int32_t);"
+   *\param fn A std::function or STL style functor with signature of "void f(std::ptrdiff_t);"
    * Pitfall: Caller should cap `num_batches` to a reasonable value based on the cost of `fn` and the value of `total`.
    *For example, if fn is as simple as: int sum=0; fn = [&](int i){sum +=i;} and `total` is 100, then num_batches should
    *be just 1.
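A quick caller-side sketch of a batch functor matching the corrected signature; the TryBatchParallelFor call shape in the comment below is an assumption about the surrounding API, not verified against this header.

// Sketch only: a batch functor with the documented "void f(std::ptrdiff_t)"
// signature (batch index type is std::ptrdiff_t, not int32_t).
#include <cstddef>
#include <vector>

void SumInBatches(const std::vector<int>& values, int num_batches) {
  std::vector<long long> partial(num_batches, 0);
  auto fn = [&](std::ptrdiff_t batch) {
    // Each batch handles a contiguous slice; per the pitfall above, keep
    // num_batches small when the per-element work is this cheap.
    const std::ptrdiff_t total = static_cast<std::ptrdiff_t>(values.size());
    const std::ptrdiff_t begin = batch * total / num_batches;
    const std::ptrdiff_t end = (batch + 1) * total / num_batches;
    for (std::ptrdiff_t i = begin; i < end; ++i) partial[batch] += values[i];
  };
  // Hypothetical call shape (assumption):
  // concurrency::ThreadPool::TryBatchParallelFor(tp, total, fn, num_batches);
  for (std::ptrdiff_t b = 0; b < num_batches; ++b) fn(b);  // serial stand-in
}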
4 changes: 4 additions & 0 deletions onnxruntime/core/framework/data_types.cc
@@ -25,6 +25,10 @@ namespace onnxruntime {
 
 MLFloat16::MLFloat16(float f) : val{math::floatToHalf(f)} {}
 
+float MLFloat16::ToFloat() const {
+  return math::halfToFloat(val);
+}
+
 // Return the MLDataType used for a generic Tensor
 template <>
 MLDataType DataTypeImpl::GetType<Tensor>() {
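With this change both conversion directions delegate to the existing math helpers. A small usage sketch of the round trip (assumes the headers above are on the include path; RoundTripExample is illustrative, not from the PR):

// Hedged usage sketch: float -> half -> float through the new ToFloat().
#include <cassert>
// #include "core/framework/data_types.h"  // assumed include for MLFloat16

void RoundTripExample() {
  onnxruntime::MLFloat16 h(1.5f);  // encodes via math::floatToHalf
  float back = h.ToFloat();        // decodes via math::halfToFloat
  assert(back == 1.5f);            // 1.5 is exactly representable in binary16
  float implicit_conv = h;         // operator float() also routes through ToFloat()
  assert(implicit_conv == back);
}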