Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
6f0735f
Add basic fp8 definitions and prn-generator
geyyer May 8, 2023
d3929cb
Format
geyyer May 8, 2023
21481b4
Add fp8<->fp32 type_convert
geyyer May 8, 2023
f07a74d
Format
geyyer May 8, 2023
5038b95
Split type_convert and cast_to/from_f8
geyyer May 11, 2023
872093b
Format
geyyer May 11, 2023
be7e055
Minor fix
geyyer May 11, 2023
9e24e2b
Minor fix
geyyer May 12, 2023
185fb54
Move fp8 utils to a separate header
geyyer May 12, 2023
4089bc6
Add elementwise ops
geyyer May 12, 2023
4ddb62b
Add fp8_convert_sr
geyyer May 12, 2023
653f951
Format
geyyer May 12, 2023
2818735
Add element op
geyyer May 12, 2023
f2cf634
Merge branch 'develop' into lwpck-726
geyyer May 12, 2023
fd2e630
Eliminate magic numbers
geyyer May 15, 2023
114c341
Split f8_convert_sr in host and device
geyyer May 15, 2023
a30a012
Format
geyyer May 16, 2023
b9bf7fb
Add some constexpr
geyyer May 18, 2023
dbd20ec
Add a datatype test
geyyer May 18, 2023
ed0cb72
Format
geyyer May 18, 2023
0818710
Another format
geyyer May 18, 2023
0c46096
Add fp8<->fp16 tests
geyyer May 23, 2023
052ab48
Update type_converts
geyyer May 23, 2023
c1ba7c6
Format
geyyer May 23, 2023
532bbe5
Add fp16 casting functions
geyyer May 23, 2023
502942f
Format
geyyer May 23, 2023
789862c
Use seed as a runtime arg
geyyer May 23, 2023
c5e2295
Use element location for PRNG
geyyer May 23, 2023
8107bbb
Format
geyyer May 23, 2023
cf0845a
Merge branch 'develop' into lwpck-726
geyyer May 23, 2023
2e7e564
Add fp8<->fp16 to PassThrough element op
geyyer May 24, 2023
8386868
Clean up
geyyer May 24, 2023
ee568bc
Merge branch 'develop' into lwpck-726
geyyer May 24, 2023
f1c2ec7
Merge host and device implementations
geyyer May 24, 2023
f730c3f
Add comments on rounding modes
geyyer Jun 9, 2023
f61c770
Remove leftover code
geyyer Jun 9, 2023
d6a666f
Put type_converts into a separate header
geyyer Jun 9, 2023
c208a8a
Put random number gen to a separate header
geyyer Jun 9, 2023
562ec12
Rearrange f8_utils' namespaces
geyyer Jun 9, 2023
08e263e
Refactor type_convert.hpp
geyyer Jun 15, 2023
2ee1c0a
Move f8_t definition
geyyer Jun 16, 2023
a734203
Merge branch 'develop' into lwpck-726
geyyer Jun 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp"

namespace ck {
namespace tensor_operation {
Expand Down Expand Up @@ -81,6 +82,36 @@ struct PassThrough
y = x;
}
#endif

// f8_t -> f8_t pass-through: source and destination have the same type, so the
// value is copied as-is and no conversion routine is called.
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
y = x;
}

// f8_t -> float pass-through: the conversion is delegated to type_convert<float>
// and the result is stored in y.
template <>
__host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
{
    const float widened = type_convert<float>(x);
    y                   = widened;
}

// float -> f8_t pass-through: the conversion is delegated to type_convert<f8_t>
// and the result is stored in y.
template <>
__host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
{
    const f8_t narrowed = type_convert<f8_t>(x);
    y                   = narrowed;
}

// f8_t -> half_t pass-through: the conversion is delegated to type_convert<half_t>
// and the result is stored in y.
template <>
__host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
{
    const half_t widened = type_convert<half_t>(x);
    y                    = widened;
}

// half_t -> f8_t pass-through: the conversion is delegated to type_convert<f8_t>
// and the result is stored in y.
template <>
__host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
{
    const f8_t narrowed = type_convert<f8_t>(x);
    y                   = narrowed;
}
};

struct UnaryConvert
Expand Down Expand Up @@ -109,6 +140,23 @@ struct ConvertBF16RTN
}
};

// Element-wise functor converting a value to fp8 using stochastic rounding (SR).
struct ConvertF8SR
{
    // y: destination, must be f8_t
    // x: source, must be float or half_t
    // Any other (Y, X) combination is rejected at compile time; the conversion
    // itself is performed by f8_convert_sr<Y>.
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        constexpr bool dst_is_f8 = is_same<Y, f8_t>::value;
        constexpr bool src_is_fp = is_same<X, float>::value || is_same<X, half_t>::value;

        // destination type check comes first, then the source type check
        static_assert(dst_is_f8, "Data type is not supported by this operation!");
        static_assert(src_is_fp, "Data type is not supported by this operation!");

        y = f8_convert_sr<Y>(x);
    }
};

struct Scale
{
__host__ __device__ Scale(float scale) : scale_(scale) {}
Expand Down
1 change: 1 addition & 0 deletions include/ck/utility/common_header.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/magic_division.hpp"
#include "ck/utility/c_style_pointer_cast.hpp"
#include "ck/utility/is_known_at_compile_time.hpp"
Expand Down
177 changes: 32 additions & 145 deletions include/ck/utility/data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
// f8_t is a plain alias of uint8_t, not a distinct type: any template
// specialization or overload written for f8_t is the same entity as the one
// for uint8_t.
using f8_t = uint8_t;

// vector_type
template <typename T, index_t N>
Expand Down Expand Up @@ -142,6 +143,13 @@ struct scalar_type<int4_t>
};
#endif

// scalar_type specialization for f8_t: the scalar type is f8_t itself and the
// reported vector size is 1.
template <>
struct scalar_type<f8_t>
{
using type = f8_t;
static constexpr index_t vector_size = 1;
};

//
template <typename T>
struct vector_type<T, 1>
Expand Down Expand Up @@ -944,151 +952,13 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;

// Generic X -> Y conversion: a plain static_cast.
// Neither Y nor X may be a reference type (checked at compile time).
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
{
    static_assert(!(std::is_reference_v<Y> || std::is_reference_v<X>));

    return static_cast<Y>(x);
}

// bf16 -> fp32: the bf16 value is widened to 32 bits and shifted left by 16, so
// it becomes the upper half of the fp32 bit pattern; the lower 16 bits are zero.
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{
    union
    {
        uint32_t int32;
        float fp32;
    } bits = {uint32_t(x) << 16};

    // read the same storage back as a float
    return bits.fp32;
}

// fp32 -> bf16 by truncation: only the upper 16 bits of the fp32 bit pattern
// are kept; the lower 16 bits are dropped without any rounding.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } bits = {x};

    const uint32_t upper_half = bits.int32 >> 16;

    return uint16_t(upper_half);
}

// bf16 -> fp16, using fp32 as the intermediate: x is first converted to float
// with type_convert<float>, then cast to half_t.
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
    return static_cast<half_t>(type_convert<float>(x));
}

// fp16 -> bf16, using fp32 as the intermediate: x is first cast to float, then
// converted to bhalf_t with type_convert<bhalf_t>.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}

// bf16 -> int32, using fp32 as the intermediate: x is first converted to float
// with type_convert<float>, then cast to int32_t.
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
    return static_cast<int32_t>(type_convert<float>(x));
}

// int32 -> bf16, using fp32 as the intermediate: x is first cast to float, then
// converted to bhalf_t with type_convert<bhalf_t>.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}

// bf16 -> int8, using fp32 as the intermediate: x is first converted to float
// with type_convert<float>, then cast to int8_t.
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
    return static_cast<int8_t>(type_convert<float>(x));
}

// int8 -> bf16, using fp32 as the intermediate: x is first cast to float, then
// converted to bhalf_t with type_convert<bhalf_t>.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}

// Primary template for bf16 conversion using round-to-nearest (RTN).
// It is only declared here, with no generic definition; the supported (Y, X)
// pairs are supplied as explicit specializations.
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);

// fp32 -> bf16 with round-to-nearest, ties-to-even (RTN), for callers that need
// more precision than plain truncation of the lower 16 bits.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } bits = {x};

    // True when at least one exponent bit is 0, i.e. the value is zero,
    // subnormal or normal (neither Inf nor NaN).
    const bool exponent_not_all_ones = (~bits.int32 & 0x7f800000) != 0;

    if(exponent_not_all_ones)
    {
        // Round to nearest, ties to even: add 0x7FFF, plus 1 when the least
        // significant bit of the bf16 mantissa (bit 16) is set. The upper half
        // is therefore incremented when the lower 16 bits exceed 0x8000, or
        // when they equal 0x8000 and the bf16 mantissa is odd. A carry out of
        // the mantissa moves into the exponent, which yields the next higher
        // representable value: a subnormal with mantissa 0x7F can become the
        // smallest normal, and exponent 0xFE with mantissa 0x7F becomes Inf.
        bits.int32 += 0x7fff + ((bits.int32 >> 16) & 1);
    }
    else if((bits.int32 & 0xffff) != 0)
    {
        // Exponent is all ones (Inf or NaN) and some of the lower 16 mantissa
        // bits are set, so the input is a NaN whose payload may live entirely
        // in the bits about to be dropped. Set the least significant bit of
        // the bf16 mantissa so the result stays a NaN (this preserves a
        // signaling NaN) instead of collapsing to Inf.
        bits.int32 |= 0x10000;
    }

    // keep only the upper 16 bits
    return uint16_t(bits.int32 >> 16);
}

// fp16 -> bf16 with RTN, using fp32 as the intermediate: x is first cast to
// float, then rounded to bhalf_t with bf16_convert_rtn<bhalf_t>.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    return bf16_convert_rtn<bhalf_t>(static_cast<float>(x));
}
// f8 vector aliases: f8xN_t is vector_type<f8_t, N>::type for N = 2, 4, 8, 16, 32, 64
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;

template <typename T>
struct NumericLimits
Expand Down Expand Up @@ -1136,4 +1006,21 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4

// NumericLimits specialization for f8_t. The limits are stored as raw 8-bit
// patterns and each accessor returns its pattern reinterpreted as f8_t through
// bit_cast.
// NOTE(review): binary_qnan has only the most significant bit set, while
// binary_max / binary_lowest leave bit 3 clear -- confirm that these patterns
// match the encoding produced by the f8 conversion routines.
template <>
struct NumericLimits<f8_t>
{
    static constexpr uint8_t binary_min    = 0b00001000; // 0x08
    static constexpr uint8_t binary_max    = 0b01110111; // 0x77
    static constexpr uint8_t binary_lowest = 0b11110111; // 0xF7
    static constexpr uint8_t binary_qnan   = 0b10000000; // 0x80

    // value with bit pattern binary_min
    __host__ __device__ static constexpr f8_t Min()
    {
        return bit_cast<f8_t>(binary_min);
    }

    // value with bit pattern binary_max
    __host__ __device__ static constexpr f8_t Max()
    {
        return bit_cast<f8_t>(binary_max);
    }

    // value with bit pattern binary_lowest (binary_max with the top bit set)
    __host__ __device__ static constexpr f8_t Lowest()
    {
        return bit_cast<f8_t>(binary_lowest);
    }

    // value with bit pattern binary_qnan
    __host__ __device__ static constexpr f8_t QuietNaN()
    {
        return bit_cast<f8_t>(binary_qnan);
    }
};

} // namespace ck
Loading