Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions sycl/include/sycl/ext/intel/experimental/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <CL/__spirv/spirv_ops.hpp>
#include <CL/sycl/half_type.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
Expand Down Expand Up @@ -43,8 +44,11 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
#endif
}

// Direct initialization
bfloat16(const storage_t &a) : value(a) {}
static bfloat16 from_bits(const storage_t &a) {
bfloat16 res;
res.value = a;
return res;
}

// Implicit conversion from float to bfloat16
bfloat16(const float &a) { value = from_float(a); }
Expand All @@ -56,9 +60,10 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {

// Implicit conversion from bfloat16 to float
operator float() const { return to_float(value); }
operator sycl::half() const { return to_float(value); }

// Get raw bits representation of bfloat16
operator storage_t() const { return value; }
storage_t raw() const { return value; }

// Logical operators (!,||,&&) are covered if we can cast to bool
explicit operator bool() { return to_float(value) != 0.0f; }
Expand Down
14 changes: 13 additions & 1 deletion sycl/test/extensions/bfloat16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using sycl::ext::intel::experimental::bfloat16;

SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y);
SYCL_EXTERNAL void foo(long x, sycl::half y);

__attribute__((noinline)) float op(float a, float b) {
// CHECK: define {{.*}} spir_func float @_Z2opff(float [[a:%.*]], float [[b:%.*]])
Expand All @@ -27,11 +28,22 @@ __attribute__((noinline)) float op(float a, float b) {
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui

bfloat16 D = some_bf16_intrinsic(A, C);
bfloat16 D = bfloat16::from_bits(some_bf16_intrinsic(A.raw(), C.raw()));
// CHECK: [[D:%.*]] = tail call spir_func zeroext i16 @_Z19some_bf16_intrinsictt(i16 zeroext [[A]], i16 zeroext [[C]])
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui

long L = bfloat16(3.14f);
// CHECK: [[L_bfloat16:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 0x40091EB860000000)
// CHECK: [[L_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[L_bfloat16]])
// CHECK: [[L:%.*]] = fptosi float [[L_float]] to i{{32|64}}

sycl::half H = bfloat16(2.71f);
// CHECK: [[H_bfloat16:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 0x4005AE1480000000)
// CHECK: [[H_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[H_bfloat16]])
// CHECK: [[H:%.*]] = fptrunc float [[H_float]] to half
foo(L, H);

return D;
// CHECK: [[RetVal:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[D]])
// CHECK: ret float [[RetVal]]
Expand Down