Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
6f0735f
Add basic fp8 definitions and prn-generator
geyyer May 8, 2023
d3929cb
Format
geyyer May 8, 2023
21481b4
Add fp8<->fp32 type_convert
geyyer May 8, 2023
f07a74d
Format
geyyer May 8, 2023
5038b95
Split type_convert and cast_to/from_f8
geyyer May 11, 2023
872093b
Format
geyyer May 11, 2023
be7e055
Minor fix
geyyer May 11, 2023
9e24e2b
Minor fix
geyyer May 12, 2023
185fb54
Move fp8 utils to a separate header
geyyer May 12, 2023
4089bc6
Add elementwise ops
geyyer May 12, 2023
4ddb62b
Add fp8_convert_sr
geyyer May 12, 2023
653f951
Format
geyyer May 12, 2023
2818735
Add element op
geyyer May 12, 2023
f2cf634
Merge branch 'develop' into lwpck-726
geyyer May 12, 2023
fd2e630
Eliminate magic numbers
geyyer May 15, 2023
114c341
Split f8_convert_sr in host and device
geyyer May 15, 2023
a30a012
Format
geyyer May 16, 2023
b9bf7fb
Add some constexpr
geyyer May 18, 2023
dbd20ec
Add a datatype test
geyyer May 18, 2023
ed0cb72
Format
geyyer May 18, 2023
0818710
Another format
geyyer May 18, 2023
0c46096
Add fp8<->fp16 tests
geyyer May 23, 2023
052ab48
Update type_converts
geyyer May 23, 2023
c1ba7c6
Format
geyyer May 23, 2023
532bbe5
Add fp16 casting functions
geyyer May 23, 2023
502942f
Format
geyyer May 23, 2023
789862c
Use seed as a runtime arg
geyyer May 23, 2023
c5e2295
Use element location for PRNG
geyyer May 23, 2023
8107bbb
Format
geyyer May 23, 2023
cf0845a
Merge branch 'develop' into lwpck-726
geyyer May 23, 2023
2e7e564
Add fp8<->fp16 to PassThrough element op
geyyer May 24, 2023
8386868
Clean up
geyyer May 24, 2023
ee568bc
Merge branch 'develop' into lwpck-726
geyyer May 24, 2023
f1c2ec7
Merge host and device implementations
geyyer May 24, 2023
f730c3f
Add comments on rounding modes
geyyer Jun 9, 2023
f61c770
Remove leftover code
geyyer Jun 9, 2023
d6a666f
Put type_converts into a separate header
geyyer Jun 9, 2023
c208a8a
Put random number gen to a separate header
geyyer Jun 9, 2023
562ec12
Rearrange f8_utils' namespaces
geyyer Jun 9, 2023
08e263e
Refactor type_convert.hpp
geyyer Jun 15, 2023
2ee1c0a
Move f8_t definition
geyyer Jun 16, 2023
a734203
Merge branch 'develop' into lwpck-726
geyyer Jun 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 1 addition & 262 deletions include/ck/utility/data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

#pragma once

#include "ck/utility/f8_utils.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/type_convert.hpp"
Comment thread
aosewski marked this conversation as resolved.
Outdated

namespace ck {

Expand Down Expand Up @@ -960,267 +960,6 @@ using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;

// Generic conversion from X to Y.
// The primary template simply performs a static_cast; type pairs that need
// special handling are covered by explicit specializations.
// Reference types are rejected at compile time for both X and Y.
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
{
    static_assert(!(std::is_reference_v<Y> || std::is_reference_v<X>));

    return static_cast<Y>(x);
}

// Convert bf16 to fp32.
// The bits of x are shifted into the upper 16 bits of a 32-bit word (the lower
// 16 bits are zero) and that word is reinterpreted as a float.
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{
    union
    {
        uint32_t int32;
        float fp32;
    } widened = {uint32_t(x) << 16};

    return widened.fp32;
}

// Convert fp32 to bf16 by truncation.
// The float is reinterpreted as a 32-bit word and only its upper 16 bits are
// kept; the lower 16 bits are discarded without rounding.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } raw = {x};

    return uint16_t(raw.int32 >> 16);
}

// Convert bf16 to fp16.
// The value is first widened to fp32 and then narrowed to half_t.
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
    return static_cast<half_t>(type_convert<float>(x));
}

// Convert fp16 to bf16.
// The value is first widened to fp32 and then passed to the fp32 -> bf16
// type_convert.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}

// Convert bf16 to int32.
// The value is first widened to fp32 and then cast to int32_t.
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
    return static_cast<int32_t>(type_convert<float>(x));
}

// Convert int32 to bf16.
// The integer is first cast to fp32 and then passed to the fp32 -> bf16
// type_convert.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}

// Convert bf16 to int8.
// The value is first widened to fp32 and then cast to int8_t.
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
    return static_cast<int8_t>(type_convert<float>(x));
}

// Convert int8 to bf16.
// The integer is first cast to fp32 and then passed to the fp32 -> bf16
// type_convert.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}

// Convert fp32 to fp8 using the standard (non-stochastic) rounding mode.
// Forwards to cast_to_f8 with negative_zero_nan and clip both enabled; the
// stochastic template flag evaluates to false and the rng argument is 0.
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
    constexpr bool use_stochastic = (rm == f8_rounding_mode::stochastic);
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    // no random bits are generated for the standard rounding path
    constexpr uint32_t rng = 0;

    return cast_to_f8<float, negative_zero_nan, clip, use_stochastic>(x, rng);
}

// Convert fp8 to fp32.
// Forwards to cast_from_f8 with the negative_zero_nan flag enabled.
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
    constexpr bool use_negative_zero_nan = true;

    return cast_from_f8<float, use_negative_zero_nan>(x);
}

// Convert fp16 to fp8 using the standard (non-stochastic) rounding mode.
// Forwards to cast_to_f8 with negative_zero_nan and clip both enabled; the
// stochastic template flag evaluates to false and the rng argument is 0.
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
    constexpr bool use_stochastic = (rm == f8_rounding_mode::stochastic);
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    // no random bits are generated for the standard rounding path
    constexpr uint32_t rng = 0;

    return cast_to_f8<half_t, negative_zero_nan, clip, use_stochastic>(x, rng);
}

// Convert fp8 to fp16.
// Forwards to cast_from_f8 with the negative_zero_nan flag enabled.
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
    constexpr bool use_negative_zero_nan = true;

    return cast_from_f8<half_t, use_negative_zero_nan>(x);
}

// Declare a template function for bf16 conversion using RTN
// (round to nearest, ties to even).
// Only the primary template is declared here, without a definition; the
// conversions are provided by explicit specializations for concrete
// <Y, X> pairs.
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);

// Convert fp32 to bf16 with round-to-nearest-even (RTN), for use when higher
// precision than plain truncation is needed.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } bits = {x};

    // The exponent field is not all ones, i.e. the value is zero, subnormal
    // or normal (neither Inf nor NaN).
    const bool is_finite = (~bits.int32 & 0x7f800000) != 0;

    if(is_finite)
    {
        // Round to nearest, ties to even: add 0x7FFF plus the current least
        // significant bit of the bf16 mantissa (bit 16). The bf16 mantissa is
        // therefore bumped when the discarded low 16 bits exceed 0x8000, or
        // equal 0x8000 while the bf16 mantissa is odd.
        // A carry out of the mantissa increments the exponent, which yields
        // the next representable value: a subnormal may become normal, and
        // the largest finite value (exponent 0xFE, mantissa 0x7F) becomes Inf.
        bits.int32 += 0x7fff + ((bits.int32 >> 16) & 1);
    }
    else if((bits.int32 & 0xffff) != 0)
    {
        // Inf or NaN with some of the low 16 mantissa bits set, so this is a
        // NaN whose payload would be lost by the shift below. Set the least
        // significant bf16 mantissa bit so that a signaling NaN stays a NaN
        // even if the remaining bf16 mantissa bits are all zero.
        bits.int32 |= 0x10000;
    }

    return uint16_t(bits.int32 >> 16);
}

// Convert fp16 to bf16 with RTN, for use when higher precision is needed.
// The value is widened to fp32 and then handed to the fp32 bf16_convert_rtn.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    return bf16_convert_rtn<bhalf_t>(static_cast<float>(x));
}

// Pseudo random number generator
// version for fp32
//
// Produces a 32-bit pseudo random value by hashing the bit pattern of `val`
// together with the element id and a seed.
//   id   - element identifier mixed into the hash
//   val  - fp32 value whose bits are hashed
//   seed - runtime seed, defaults to the compile-time seed_t
// Returns the mixed 32-bit value.
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    static_assert(sizeof(T) == sizeof(uint32_t), "fp32 is expected to be 32 bits wide");

    // Read the bits of val through a union, the same way the bf16 conversions
    // do, instead of dereferencing a reinterpret_cast'ed pointer (which
    // violates the strict-aliasing rules).
    union
    {
        T fp32;
        uint32_t int32;
    } u = {val};

    const uint32_t x   = u.int32;
    uint32_t drop_bits = x & 0xFFFFu;
    drop_bits ^= x >> 16;
    // rotate the low 16 bits right by 5
    drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
    drop_bits *= 0x7000149;
    // NOTE: If id is in 64 bit, we are only using lower 32 bit.
    // So, it can have an effect of using same id for multiple elements when the id is very
    // large!
    // The multiplication is done on uint32_t so that it wraps around; in the
    // type of id it could overflow, which is undefined if that type is signed.
    // The low 32 bits of the product are unchanged by this.
    const uint32_t id_mix = static_cast<uint32_t>(id) * 229791u;
    uint32_t rng          = (drop_bits ^ 0x13371337 ^ id_mix ^ seed);
    return rng;
}

// version for fp16
//
// Produces a 32-bit pseudo random value by hashing the bit pattern of `val`
// together with the element id and a seed.
//   id   - element identifier mixed into the hash
//   val  - fp16 value whose bits are hashed
//   seed - runtime seed, defaults to the compile-time seed_t
// Returns the mixed 32-bit value.
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    static_assert(sizeof(T) == sizeof(uint16_t), "fp16 is expected to be 16 bits wide");

    // Read the bits of val through a union, the same way the bf16 conversions
    // do, instead of dereferencing a reinterpret_cast'ed pointer (which
    // violates the strict-aliasing rules).
    union
    {
        T fp16;
        uint16_t int16;
    } u = {val};

    const uint16_t x   = u.int16;
    uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
    // rotate the 16 bits right by 5
    drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
    drop_bits *= 0x7000149;
    // NOTE: If id is in 64 bit, we are only using lower 32 bit.
    // So, it can have an effect of using same id for multiple elements when the id is very
    // large!
    // The multiplication is done on uint32_t so that it wraps around; in the
    // type of id it could overflow, which is undefined if that type is signed.
    // The low 32 bits of the product are unchanged by this.
    const uint32_t id_mix = static_cast<uint32_t>(id) * 229791u;
    uint32_t rng          = (drop_bits ^ 0x13371337 ^ id_mix ^ seed);
    return rng;
}

// Fallback for every type other than fp32 and fp16: no random number is
// generated, all arguments are ignored and 0 is returned.
template <typename T,
          uint32_t seed_t,
          std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
    // silence unused-parameter warnings
    static_cast<void>(id);
    static_cast<void>(val);
    static_cast<void>(seed);

    return 0;
}

// Declare a template function for fp8 conversion using SR
// (stochastic rounding).
// Only the primary template is declared here, without a definition; the
// conversions are provided by explicit specializations for concrete
// <Y, X> pairs.
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);

// Convert fp32 to fp8 with stochastic rounding.
// The random bits come from prand_generator with a fixed seed of 42, and
// cast_to_f8 is invoked with negative_zero_nan, clip and the stochastic flag
// all enabled.
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr bool use_stochastic = (rm == f8_rounding_mode::stochastic);
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr int seed               = 42;

    // no thread id is used here: the address of the local copy of x serves as
    // the element id for the generator
    const auto element_id      = reinterpret_cast<uintptr_t>(&x);
    const uint32_t random_bits = prand_generator<float, seed>(element_id, x);

    return cast_to_f8<float, negative_zero_nan, clip, use_stochastic>(x, random_bits);
}

// Convert fp16 to fp8 with stochastic rounding.
// The random bits come from prand_generator with a fixed seed of 42, and
// cast_to_f8 is invoked with negative_zero_nan, clip and the stochastic flag
// all enabled.
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr bool use_stochastic = (rm == f8_rounding_mode::stochastic);
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr int seed               = 42;

    // no thread id is used here: the address of the local copy of x serves as
    // the element id for the generator
    const auto element_id      = reinterpret_cast<uintptr_t>(&x);
    const uint32_t random_bits = prand_generator<half_t, seed>(element_id, x);

    return cast_to_f8<half_t, negative_zero_nan, clip, use_stochastic>(x, random_bits);
}

template <typename T>
struct NumericLimits
{
Expand Down
45 changes: 23 additions & 22 deletions include/ck/utility/f8_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,26 @@

#pragma once

#include "ck/utility/statically_indexed_array.hpp"

namespace ck {
Comment thread
aosewski marked this conversation as resolved.

// Storage type for 8-bit floating point (fp8) values: a plain byte that holds
// the packed sign, exponent and mantissa bits.
using f8_t = uint8_t;
// 16-bit floating point type, an alias for the compiler's _Float16.
using half_t = _Float16;

// fp8 rounding modes
// use standard for rounding to nearest, the faster one
// use stochastic for stochastic rounding, helps to avoid error accumulation
enum class f8_rounding_mode
{
standard,
stochastic
Comment thread
aosewski marked this conversation as resolved.
};

} // namespace ck

namespace ck::utils {

namespace {

template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
Expand Down Expand Up @@ -127,17 +133,6 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
}

// Convert a half_t or float value to fp8.
// Checks the source type at compile time and forwards to run_cast_to_f8 with
// the same template flags.
//   x   - value to convert
//   rng - random bits handed through to run_cast_to_f8
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
    // only fp16 and fp32 sources are supported
    static_assert(std::is_same_v<T, half_t> || std::is_same_v<T, float>,
                  "Only half and float can be casted to f8.");

    return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}

template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x)
{
Expand Down Expand Up @@ -207,15 +202,8 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
mantissa <<= sh;
exponent += 1 - sh;
/*
exponent++;
while(mantissa<(1<<wm)) {
mantissa <<= 1;
exponent--;
}
*/
mantissa &= ((1 << f8_mant) - 1);
exponent += 1 - sh;
}
exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant;
Expand All @@ -232,6 +220,19 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
return *(reinterpret_cast<const T*>(&retval));
}

} // namespace

// Convert a half_t or float value to fp8.
// Checks the source type at compile time and forwards to run_cast_to_f8 with
// the same template flags.
//   x   - value to convert
//   rng - random bits handed through to run_cast_to_f8
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
    // only fp16 and fp32 sources are supported
    static_assert(std::is_same_v<T, half_t> || std::is_same_v<T, float>,
                  "Only half and float can be casted to f8.");

    return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}

template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
{
Expand All @@ -247,4 +248,4 @@ __host__ __device__ T cast_from_f8(f8_t x)
return run_cast_from_f8<T, negative_zero_nan>(x);
}

} // namespace ck
} // namespace ck::utils
Loading