Skip to content
This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Add round-to-nearest-even rounding to float2half(). #368

Merged
merged 4 commits into from
Jan 28, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 94 additions & 28 deletions mshadow/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
#include <x86intrin.h>
#endif // MSHADOW_USE_F16C

// This flag dictates rounding for the float2half() routine only (used generally on Windows),
// not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even.
#ifndef MSHADOW_HALF_ROUND_TO_NEAREST
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
#endif

#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
#define MSHADOW_CUDA_HALF 1
#include <cuda_fp16.h>
Expand Down Expand Up @@ -159,12 +165,18 @@ class MSHADOW_ALIGNED(2) half_t {
uint32_t ui;
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
};

static int const shift = 13;
static int const fp16FractionBits = 10;
static int const fp32FractionBits = 23;
static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
static int const shift = fp32FractionBits - fp16FractionBits; // == 13
static int const shiftSign = 16;
static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

static int32_t const infN = 0x7F800000; // flt32 infinity
static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32
static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
static int32_t const signN = 0x80000000; // flt32 sign bit

static int32_t const infC = infN >> shift;
Expand All @@ -183,37 +195,91 @@ class MSHADOW_ALIGNED(2) half_t {
static int32_t const minD = minC - subC - 1;

MSHADOW_XINLINE uint16_t float2half(const float& value) const {
Bits v, s;
Bits v;
v.f = value;
uint32_t sign = v.si & signN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
return v.ui | sign;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position

if (v.si <= maxZ) {
// Handle eventual zeros here to ensure vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
// The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
// when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
// bits to the right of the lsb are 1000... (including flt32 significand bits
// that may be lost during the above vshift). The first term below will always
// be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
// right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
// the proper test of the flt32 significand bits, including those lost during the vshift.
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// Rounding may increase the exponent to 1, but that's OK.
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
} else if (v.si <= maxN) {
// Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// Rounding may increase the exponent, possibly creating an inf, but that's OK.
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}

v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}

// Same as above routine, except for addition of volatile keyword
MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*)
Bits v, s;
Bits v;
v.f = value;
uint32_t sign = v.si & signN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
return v.ui | sign;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position

if (v.si <= maxZ) {
// Handle eventual zeros here to ensure vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// Rounding may increase the exponent to 1, but that's OK.
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
} else if (v.si <= maxN) {
// Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// Rounding may increase the exponent, possibly creating an inf, but that's OK.
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}

v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}

MSHADOW_XINLINE float half2float(const uint16_t& value) const {
Expand Down