diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 5f2ee620de3c8..18ef03cc70607 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -594,6 +594,9 @@ extern SYCL_EXTERNAL void __spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr, size_t NumBytes) noexcept; +extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept; +extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept; + #else // if !__SYCL_DEVICE_ONLY__ template diff --git a/sycl/include/CL/sycl/feature_test.hpp b/sycl/include/CL/sycl/feature_test.hpp index c3a426cdb81c4..4625cfa06fed7 100644 --- a/sycl/include/CL/sycl/feature_test.hpp +++ b/sycl/include/CL/sycl/feature_test.hpp @@ -23,6 +23,7 @@ namespace sycl { #ifndef SYCL_EXT_ONEAPI_MATRIX #define SYCL_EXT_ONEAPI_MATRIX 2 #endif +#define SYCL_EXT_INTEL_BF16_CONVERSION 1 } // namespace sycl } // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/include/sycl/ext/intel/experimental/bfloat16.hpp b/sycl/include/sycl/ext/intel/experimental/bfloat16.hpp new file mode 100644 index 0000000000000..7a74c33ab7229 --- /dev/null +++ b/sycl/include/sycl/ext/intel/experimental/bfloat16.hpp @@ -0,0 +1,148 @@ +//==--------- bfloat16.hpp ------- SYCL bfloat16 conversion ----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace ext { +namespace intel { +namespace experimental { + +class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 { + using storage_t = uint16_t; + storage_t value; + +public: + bfloat16() = default; + bfloat16(const bfloat16 &) = default; + ~bfloat16() = default; + + // Explicit conversion functions + static storage_t from_float(const float &a) { +#if defined(__SYCL_DEVICE_ONLY__) + return __spirv_ConvertFToBF16INTEL(a); +#else + throw exception{errc::feature_not_supported, + "Bfloat16 conversion is not supported on host device"}; +#endif + } + static float to_float(const storage_t &a) { +#if defined(__SYCL_DEVICE_ONLY__) + return __spirv_ConvertBF16ToFINTEL(a); +#else + throw exception{errc::feature_not_supported, + "Bfloat16 conversion is not supported on host device"}; +#endif + } + + // Direct initialization + bfloat16(const storage_t &a) : value(a) {} + + // Implicit conversion from float to bfloat16 + bfloat16(const float &a) { value = from_float(a); } + + bfloat16 &operator=(const float &rhs) { + value = from_float(rhs); + return *this; + } + + // Implicit conversion from bfloat16 to float + operator float() const { return to_float(value); } + + // Get raw bits representation of bfloat16 + operator storage_t() const { return value; } + + // Logical operators (!,||,&&) are covered if we can cast to bool + explicit operator bool() { return to_float(value) != 0.0f; } + + // Unary minus operator overloading + friend bfloat16 operator-(bfloat16 &lhs) { + return bfloat16{-to_float(lhs.value)}; + } + +// Increment and decrement operators overloading +#define OP(op) \ + friend bfloat16 &operator op(bfloat16 &lhs) { \ + float f = to_float(lhs.value); \ + lhs.value = from_float(op f); \ + return lhs; \ + } \ + friend bfloat16 operator op(bfloat16 &lhs, int) { \ + bfloat16 old = lhs; \ + operator op(lhs); \ + return old; \ + } + OP(++) + OP(--) +#undef OP + + // Assignment operators overloading +#define OP(op) \ + friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } \ + template \ + friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } \ + template friend T &operator op(T &lhs, const bfloat16 &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } + OP(+=) + OP(-=) + OP(*=) + OP(/=) +#undef OP + +// Binary operators overloading +#define OP(type, op) \ + friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } \ + template \ + friend type operator op(const bfloat16 &lhs, const T &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } \ + template \ + friend type operator op(const T &lhs, const bfloat16 &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } + OP(bfloat16, +) + OP(bfloat16, -) + OP(bfloat16, *) + OP(bfloat16, /) + OP(bool, ==) + OP(bool, !=) + OP(bool, <) + OP(bool, >) + OP(bool, <=) + OP(bool, >=) +#undef OP + + // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported + // for floating-point types. +}; + +} // namespace experimental +} // namespace intel +} // namespace ext + +namespace __SYCL2020_DEPRECATED("use 'ext::intel' instead") INTEL { + using namespace ext::intel; +} +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/test/extensions/bfloat16.cpp b/sycl/test/extensions/bfloat16.cpp new file mode 100644 index 0000000000000..f80e4fb7ef48d --- /dev/null +++ b/sycl/test/extensions/bfloat16.cpp @@ -0,0 +1,51 @@ +// RUN: %clangxx -fsycl-device-only -S -Xclang -emit-llvm %s -o - | FileCheck %s + +#include +#include + +using sycl::ext::intel::experimental::bfloat16; + +SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y); + +__attribute__((noinline)) +float op(float a, float b) { + bfloat16 A {a}; +// CHECK: [[A:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %a) +// CHECK-NOT: fptoui + + bfloat16 B {b}; +// CHECK: [[B:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %b) +// CHECK-NOT: fptoui + + bfloat16 C = A + B; +// CHECK: [[A_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[A]]) +// CHECK: [[B_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[B]]) +// CHECK: [[Add:%.*]] = fadd float [[A_float]], [[B_float]] +// CHECK: [[C:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float [[Add]]) +// CHECK-NOT: uitofp +// CHECK-NOT: fptoui + + bfloat16 D = some_bf16_intrinsic(A, C); +// CHECK: [[D:%.*]] = tail call spir_func zeroext i16 @_Z19some_bf16_intrinsictt(i16 zeroext [[A]], i16 zeroext [[C]]) +// CHECK-NOT: uitofp +// CHECK-NOT: fptoui + + return D; +// CHECK: [[RetVal:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[D]]) +// CHECK: ret float [[RetVal]] +// CHECK-NOT: uitofp +// CHECK-NOT: fptoui +} + +int main(int argc, char *argv[]) { + float data[3] = {7.0, 8.1, 0.0}; + cl::sycl::queue deviceQueue; + cl::sycl::buffer buf{data, cl::sycl::range<1>{3}}; + + deviceQueue.submit([&](cl::sycl::handler &cgh) { + auto numbers = buf.get_access(cgh); + cgh.single_task( + [=]() { numbers[2] = op(numbers[0], numbers[1]); }); + }); + return 0; +}