Skip to content
3 changes: 3 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,9 @@ extern SYCL_EXTERNAL void
__spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr,
size_t NumBytes) noexcept;

extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept;
extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept;

#else // if !__SYCL_DEVICE_ONLY__

template <typename dataT>
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/feature_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace sycl {
#ifndef SYCL_EXT_ONEAPI_MATRIX
#define SYCL_EXT_ONEAPI_MATRIX 2
#endif
#define SYCL_EXT_INTEL_BF16_CONVERSION 1

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
148 changes: 148 additions & 0 deletions sycl/include/sycl/ext/intel/experimental/bfloat16.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
//==--------- bfloat16.hpp ------- SYCL bfloat16 conversion ----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <CL/__spirv/spirv_ops.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace intel {
namespace experimental {

class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
using storage_t = uint16_t;
storage_t value;

public:
bfloat16() = default;
bfloat16(const bfloat16 &) = default;
~bfloat16() = default;

// Explicit conversion functions
static storage_t from_float(const float &a) {
#if defined(__SYCL_DEVICE_ONLY__)
return __spirv_ConvertFToBF16INTEL(a);
#else
throw exception{errc::feature_not_supported,
"Bfloat16 conversion is not supported on host device"};
#endif
}
static float to_float(const storage_t &a) {
#if defined(__SYCL_DEVICE_ONLY__)
return __spirv_ConvertBF16ToFINTEL(a);
#else
throw exception{errc::feature_not_supported,
"Bfloat16 conversion is not supported on host device"};
#endif
}

// Direct initialization
bfloat16(const storage_t &a) : value(a) {}

// Implicit conversion from float to bfloat16
bfloat16(const float &a) { value = from_float(a); }

bfloat16 &operator=(const float &rhs) {
value = from_float(rhs);
return *this;
}

// Implicit conversion from bfloat16 to float
operator float() const { return to_float(value); }

// Get raw bits representation of bfloat16
operator storage_t() const { return value; }

// Logical operators (!,||,&&) are covered if we can cast to bool
explicit operator bool() { return to_float(value) != 0.0f; }

// Unary minus operator overloading
friend bfloat16 operator-(bfloat16 &lhs) {
return bfloat16{-to_float(lhs.value)};
}

// Increment and decrement operators overloading
#define OP(op) \
friend bfloat16 &operator op(bfloat16 &lhs) { \
float f = to_float(lhs.value); \
lhs.value = from_float(op f); \
return lhs; \
} \
friend bfloat16 operator op(bfloat16 &lhs, int) { \
bfloat16 old = lhs; \
operator op(lhs); \
return old; \
}
OP(++)
OP(--)
#undef OP

// Assignment operators overloading
#define OP(op) \
friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
} \
template <typename T> \
friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
} \
template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
}
OP(+=)
OP(-=)
OP(*=)
OP(/=)
#undef OP

// Binary operators overloading
#define OP(type, op) \
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
} \
template <typename T> \
friend type operator op(const bfloat16 &lhs, const T &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
} \
template <typename T> \
friend type operator op(const T &lhs, const bfloat16 &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
}
OP(bfloat16, +)
OP(bfloat16, -)
OP(bfloat16, *)
OP(bfloat16, /)
OP(bool, ==)
OP(bool, !=)
OP(bool, <)
OP(bool, >)
OP(bool, <=)
OP(bool, >=)
#undef OP

// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
// for floating-point types.
};

} // namespace experimental
} // namespace intel
} // namespace ext

namespace __SYCL2020_DEPRECATED("use 'ext::intel' instead") INTEL {
using namespace ext::intel;
}
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
51 changes: 51 additions & 0 deletions sycl/test/extensions/bfloat16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: %clangxx -fsycl-device-only -S -Xclang -emit-llvm %s -o - | FileCheck %s

#include <sycl/sycl.hpp>
#include <sycl/ext/intel/experimental/bfloat16.hpp>

using sycl::ext::intel::experimental::bfloat16;

SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y);

__attribute__((noinline))
float op(float a, float b) {
bfloat16 A {a};
// CHECK: [[A:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %a)
// CHECK-NOT: fptoui

bfloat16 B {b};
// CHECK: [[B:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %b)
// CHECK-NOT: fptoui

bfloat16 C = A + B;
// CHECK: [[A_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[A]])
// CHECK: [[B_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[B]])
// CHECK: [[Add:%.*]] = fadd float [[A_float]], [[B_float]]
// CHECK: [[C:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float [[Add]])
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui

bfloat16 D = some_bf16_intrinsic(A, C);
// CHECK: [[D:%.*]] = tail call spir_func zeroext i16 @_Z19some_bf16_intrinsictt(i16 zeroext [[A]], i16 zeroext [[C]])
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui

return D;
// CHECK: [[RetVal:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[D]])
// CHECK: ret float [[RetVal]]
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui
}

int main(int argc, char *argv[]) {
float data[3] = {7.0, 8.1, 0.0};
cl::sycl::queue deviceQueue;
cl::sycl::buffer<float, 1> buf{data, cl::sycl::range<1>{3}};

deviceQueue.submit([&](cl::sycl::handler &cgh) {
auto numbers = buf.get_access<cl::sycl::access::mode::read_write>(cgh);
cgh.single_task<class simple_kernel>(
[=]() { numbers[2] = op(numbers[0], numbers[1]); });
});
return 0;
}