Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl/include/sycl/ext/intel/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#pragma once
#include <sycl/ext/intel/math/imf_half_trivial.hpp>
#include <sycl/ext/intel/math/imf_bf16.hpp>
#include <sycl/half_type.hpp>
#include <type_traits>

Expand Down
205 changes: 205 additions & 0 deletions sycl/include/sycl/ext/intel/math/imf_bf16.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
//==-------------------- imf_bf16.hpp - bfloat16 utils ---------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// C++ APIs for bfloat16 util functions.
//===----------------------------------------------------------------------===//

#pragma once
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
#include <type_traits>

// Shorthand for the user-visible SYCL bfloat16 class used by the C++ wrappers
// in this header.
using sycl_bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
// Raw 16-bit representation passed to and returned from the C __imf_*
// functions declared in this header.
using _iml_bfloat16_internal = uint16_t;

// C-linkage conversion entry points operating on the raw 16-bit bfloat16
// representation. Only declared here; the definitions live elsewhere.
// NOTE(review): the _rd/_rn/_ru/_rz suffixes presumably select the rounding
// mode (down / nearest / up / toward zero) - confirm against the
// implementation.
extern "C" {
float __imf_bfloat162float(_iml_bfloat16_internal);
_iml_bfloat16_internal __imf_float2bfloat16(float);
_iml_bfloat16_internal __imf_float2bfloat16_rd(float);
_iml_bfloat16_internal __imf_float2bfloat16_rn(float);
_iml_bfloat16_internal __imf_float2bfloat16_ru(float);
_iml_bfloat16_internal __imf_float2bfloat16_rz(float);
};

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
namespace intel {
namespace math {

// Need to ensure that sycl bfloat16 defined in bfloat16.hpp is compatible
// with uint16_t in layout.
#if __cplusplus >= 201703L
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to check this - C++17 is the minimal supported version.

// The wrappers in this header bit-cast between sycl_bfloat16 and
// _iml_bfloat16_internal, which requires both types to have the same size.
static_assert(sizeof(sycl_bfloat16) == sizeof(_iml_bfloat16_internal),
"sycl bfloat16 is not compatible with _iml_bfloat16_internal.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a distinct alias at all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @aelovikov-intel
What distinct alias do you refer to?
Thanks very much.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we use the same type in all places? Why do we need both sycl_bfloat16 and _iml_bfloat16_internal?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @aelovikov-intel
The functions defined in the sycl::ext::intel::math namespace, such as hge, hgt, ..., are C++ wrappers for C functions provided in libdevice. Each of these C++ functions calls the corresponding C function, and SYCL bfloat16 users will only work with the C++ functions. SYCL bfloat16 is a C++ class defined in https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/bfloat16.hpp, and its current implementation is based on the "uint16_t" type. So our C++ functions such as hge, hgt, ... accept only the user-visible SYCL bfloat16. However, the C functions implemented in libdevice can't use the C++ SYCL bfloat16 type; they can only accept native C types, so we need to use the native C uint16_t type for them. In the future, we may use a native bfloat16 type instead of emulation based on uint16_t.
Thanks very much.


/// Converts a bfloat16 value to float.
///
/// The argument is bit-cast to its raw 16-bit representation and forwarded to
/// the C function __imf_bfloat162float.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b bfloat16 value to convert.
/// \return the float produced by __imf_bfloat162float.
inline float bfloat162float(sycl_bfloat16 b) {
  const _iml_bfloat16_internal raw =
      __builtin_bit_cast(_iml_bfloat16_internal, b);
  return __imf_bfloat162float(raw);
}

/// Converts a float to bfloat16 via the C function __imf_float2bfloat16.
///
/// The 16-bit result of the C function is bit-cast to sycl_bfloat16.
/// NOTE(review): the rounding mode is whatever __imf_float2bfloat16
/// implements; it is not visible from this header.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param f float value to convert.
/// \return the converted bfloat16 value.
inline sycl_bfloat16 float2bfloat16(float f) {
  const _iml_bfloat16_internal raw = __imf_float2bfloat16(f);
  return __builtin_bit_cast(sycl_bfloat16, raw);
}

/// Converts a float to bfloat16 via the C function __imf_float2bfloat16_rd.
///
/// The 16-bit result of the C function is bit-cast to sycl_bfloat16.
/// NOTE(review): the "_rd" suffix presumably means round-down (toward negative
/// infinity) - confirm against the __imf_float2bfloat16_rd implementation.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param f float value to convert.
/// \return the converted bfloat16 value.
inline sycl_bfloat16 float2bfloat16_rd(float f) {
  const _iml_bfloat16_internal raw = __imf_float2bfloat16_rd(f);
  return __builtin_bit_cast(sycl_bfloat16, raw);
}

/// Converts a float to bfloat16 via the C function __imf_float2bfloat16_rn.
///
/// The 16-bit result of the C function is bit-cast to sycl_bfloat16.
/// NOTE(review): the "_rn" suffix presumably means round-to-nearest - confirm
/// against the __imf_float2bfloat16_rn implementation.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param f float value to convert.
/// \return the converted bfloat16 value.
inline sycl_bfloat16 float2bfloat16_rn(float f) {
  const _iml_bfloat16_internal raw = __imf_float2bfloat16_rn(f);
  return __builtin_bit_cast(sycl_bfloat16, raw);
}

/// Converts a float to bfloat16 via the C function __imf_float2bfloat16_ru.
///
/// The 16-bit result of the C function is bit-cast to sycl_bfloat16.
/// NOTE(review): the "_ru" suffix presumably means round-up (toward positive
/// infinity) - confirm against the __imf_float2bfloat16_ru implementation.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param f float value to convert.
/// \return the converted bfloat16 value.
inline sycl_bfloat16 float2bfloat16_ru(float f) {
  const _iml_bfloat16_internal raw = __imf_float2bfloat16_ru(f);
  return __builtin_bit_cast(sycl_bfloat16, raw);
}

/// Converts a float to bfloat16 via the C function __imf_float2bfloat16_rz.
///
/// The 16-bit result of the C function is bit-cast to sycl_bfloat16.
/// NOTE(review): the "_rz" suffix presumably means round-toward-zero - confirm
/// against the __imf_float2bfloat16_rz implementation.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param f float value to convert.
/// \return the converted bfloat16 value.
inline sycl_bfloat16 float2bfloat16_rz(float f) {
  const _iml_bfloat16_internal raw = __imf_float2bfloat16_rz(f);
  return __builtin_bit_cast(sycl_bfloat16, raw);
}

bool hisnan(sycl_bfloat b) { return sycl::isnan(bfloat162float(b)); }

bool hisinf(sycl_bfloat b) { return sycl::isinf(bfloat162float(b)); }

/// Ordered equality comparison of two bfloat16 values.
///
/// Returns false if either operand is NaN; otherwise compares the raw 16-bit
/// representations of the operands.
/// NOTE(review): because the comparison is bitwise, +0 (0x0000) and -0
/// (0x8000) compare as not equal - confirm this is intended.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if neither operand is NaN and both have identical bits.
inline bool heq(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return false;
  const _iml_bfloat16_internal lhs_bits =
      __builtin_bit_cast(_iml_bfloat16_internal, b1);
  const _iml_bfloat16_internal rhs_bits =
      __builtin_bit_cast(_iml_bfloat16_internal, b2);
  return lhs_bits == rhs_bits;
}

/// Unordered equality comparison of two bfloat16 values.
///
/// Returns true if either operand is NaN; otherwise compares the raw 16-bit
/// representations of the operands.
///
/// The NaN check now tests both b1 and b2. It previously tested b1 twice, so a
/// NaN in b2 alone fell through to the bitwise comparison and returned false.
/// NOTE(review): because the comparison is bitwise, +0 (0x0000) and -0
/// (0x8000) compare as not equal - confirm this is intended.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if either operand is NaN or both have identical bits.
inline bool hequ(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return true;
  const _iml_bfloat16_internal lhs_bits =
      __builtin_bit_cast(_iml_bfloat16_internal, b1);
  const _iml_bfloat16_internal rhs_bits =
      __builtin_bit_cast(_iml_bfloat16_internal, b2);
  return lhs_bits == rhs_bits;
}

/// Ordered greater-than-or-equal comparison of two bfloat16 values.
///
/// Returns false if either operand is NaN; otherwise widens both operands to
/// float and evaluates b1 >= b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if neither operand is NaN and b1 >= b2.
inline bool hge(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return false;
  const float lhs = bfloat162float(b1);
  const float rhs = bfloat162float(b2);
  return lhs >= rhs;
}

/// Unordered greater-than-or-equal comparison of two bfloat16 values.
///
/// Returns true if either operand is NaN; otherwise widens both operands to
/// float and evaluates b1 >= b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if either operand is NaN or b1 >= b2.
inline bool hgeu(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return true;
  const float lhs = bfloat162float(b1);
  const float rhs = bfloat162float(b2);
  return lhs >= rhs;
}

/// Ordered greater-than comparison of two bfloat16 values.
///
/// Returns false if either operand is NaN; otherwise widens both operands to
/// float and evaluates b1 > b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if neither operand is NaN and b1 > b2.
inline bool hgt(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return false;
  const float lhs = bfloat162float(b1);
  const float rhs = bfloat162float(b2);
  return lhs > rhs;
}

/// Unordered greater-than comparison of two bfloat16 values.
///
/// Returns true if either operand is NaN; otherwise widens both operands to
/// float and evaluates b1 > b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if either operand is NaN or b1 > b2.
inline bool hgtu(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return true;
  const float lhs = bfloat162float(b1);
  const float rhs = bfloat162float(b2);
  return lhs > rhs;
}

/// Ordered less-than-or-equal comparison of two bfloat16 values.
///
/// Returns false if either operand is NaN; otherwise widens both operands to
/// float and evaluates b1 <= b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if neither operand is NaN and b1 <= b2.
inline bool hle(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return false;
  const float lhs = bfloat162float(b1);
  const float rhs = bfloat162float(b2);
  return lhs <= rhs;
}

/// Unordered less-than-or-equal comparison of two bfloat16 values.
///
/// Returns true if either operand is NaN; otherwise widens both operands to
/// float and evaluates b1 <= b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if either operand is NaN or b1 <= b2.
inline bool hleu(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return true;
  const float lhs = bfloat162float(b1);
  const float rhs = bfloat162float(b2);
  return lhs <= rhs;
}

/// Ordered less-than comparison of two bfloat16 values.
///
/// Returns false if either operand is NaN; otherwise widens both operands to
/// float and evaluates b1 < b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if neither operand is NaN and b1 < b2.
inline bool hlt(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return false;
  const float lhs = bfloat162float(b1);
  const float rhs = bfloat162float(b2);
  return lhs < rhs;
}

/// Unordered less-than comparison of two bfloat16 values.
///
/// Returns true if either operand is NaN; otherwise widens both operands to
/// float and evaluates b1 < b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if either operand is NaN or b1 < b2.
inline bool hltu(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return true;
  const float lhs = bfloat162float(b1);
  const float rhs = bfloat162float(b2);
  return lhs < rhs;
}

/// Returns the larger of two bfloat16 values, preferring a non-NaN operand.
///
/// - Both operands NaN: returns the value with bit pattern 0x7FC0.
/// - Exactly one operand NaN: returns the other operand.
/// - Otherwise: returns b1 if hgt(b1, b2) holds, else b2 (so b2 is returned
///   when the operands compare equal).
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 first operand.
/// \param b2 second operand.
/// \return the selected operand, or the 0x7FC0 NaN when both are NaN.
inline sycl_bfloat16 hmax(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  const bool b1_is_nan = hisnan(b1);
  const bool b2_is_nan = hisnan(b2);
  if (b1_is_nan && b2_is_nan) {
    // Sign 0, all exponent bits set, top mantissa bit set: a NaN encoding.
    _iml_bfloat16_internal nan_bits = 0x7FC0;
    return __builtin_bit_cast(sycl_bfloat16, nan_bits);
  }
  if (b1_is_nan)
    return b2;
  if (b2_is_nan)
    return b1;
  return hgt(b1, b2) ? b1 : b2;
}

/// Returns the larger of two bfloat16 values, propagating NaN.
///
/// If either operand is NaN the value with bit pattern 0x7FC0 is returned.
/// Otherwise returns b1 if hgt(b1, b2) holds, else b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 first operand.
/// \param b2 second operand.
/// \return the selected operand, or the 0x7FC0 NaN when any operand is NaN.
inline sycl_bfloat16 hmax_nan(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2)) {
    // Sign 0, all exponent bits set, top mantissa bit set: a NaN encoding.
    _iml_bfloat16_internal nan_bits = 0x7FC0;
    return __builtin_bit_cast(sycl_bfloat16, nan_bits);
  }
  return hgt(b1, b2) ? b1 : b2;
}

/// Returns the smaller of two bfloat16 values, preferring a non-NaN operand.
///
/// - Both operands NaN: returns the value with bit pattern 0x7FC0.
/// - Exactly one operand NaN: returns the other operand.
/// - Otherwise: returns b1 if hlt(b1, b2) holds, else b2 (so b2 is returned
///   when the operands compare equal).
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 first operand.
/// \param b2 second operand.
/// \return the selected operand, or the 0x7FC0 NaN when both are NaN.
inline sycl_bfloat16 hmin(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  const bool b1_is_nan = hisnan(b1);
  const bool b2_is_nan = hisnan(b2);
  if (b1_is_nan && b2_is_nan) {
    // Sign 0, all exponent bits set, top mantissa bit set: a NaN encoding.
    _iml_bfloat16_internal nan_bits = 0x7FC0;
    return __builtin_bit_cast(sycl_bfloat16, nan_bits);
  }
  if (b1_is_nan)
    return b2;
  if (b2_is_nan)
    return b1;
  return hlt(b1, b2) ? b1 : b2;
}

/// Returns the smaller of two bfloat16 values, propagating NaN.
///
/// If either operand is NaN the value with bit pattern 0x7FC0 is returned.
/// Otherwise returns b1 if hlt(b1, b2) holds, else b2.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 first operand.
/// \param b2 second operand.
/// \return the selected operand, or the 0x7FC0 NaN when any operand is NaN.
inline sycl_bfloat16 hmin_nan(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2)) {
    // Sign 0, all exponent bits set, top mantissa bit set: a NaN encoding.
    _iml_bfloat16_internal nan_bits = 0x7FC0;
    return __builtin_bit_cast(sycl_bfloat16, nan_bits);
  }
  return hlt(b1, b2) ? b1 : b2;
}

/// Ordered not-equal comparison of two bfloat16 values.
///
/// Returns false if either operand is NaN; otherwise compares the raw 16-bit
/// representations of the operands for inequality.
/// NOTE(review): because the comparison is bitwise, +0 (0x0000) and -0
/// (0x8000) compare as not equal - confirm this is intended.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if neither operand is NaN and their bits differ.
inline bool hne(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return false;
  const _iml_bfloat16_internal lhs_bits =
      __builtin_bit_cast(_iml_bfloat16_internal, b1);
  const _iml_bfloat16_internal rhs_bits =
      __builtin_bit_cast(_iml_bfloat16_internal, b2);
  return lhs_bits != rhs_bits;
}

/// Unordered not-equal comparison of two bfloat16 values.
///
/// Returns true if either operand is NaN; otherwise compares the raw 16-bit
/// representations of the operands for inequality.
/// NOTE(review): because the comparison is bitwise, +0 (0x0000) and -0
/// (0x8000) compare as not equal - confirm this is intended.
///
/// Declared inline because this function is defined in a header; without it,
/// including the header from several translation units violates the ODR.
///
/// \param b1 left operand.
/// \param b2 right operand.
/// \return true if either operand is NaN or their bits differ.
inline bool hneu(sycl_bfloat16 b1, sycl_bfloat16 b2) {
  if (hisnan(b1) || hisnan(b2))
    return true;
  const _iml_bfloat16_internal lhs_bits =
      __builtin_bit_cast(_iml_bfloat16_internal, b1);
  const _iml_bfloat16_internal rhs_bits =
      __builtin_bit_cast(_iml_bfloat16_internal, b2);
  return lhs_bits != rhs_bits;
}
#endif
} // namespace math
} // namespace intel
} // namespace ext
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl