From 92b20f0e2c631b4307225cb2b30a1cd66894b768 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Wed, 26 Jan 2022 17:02:14 +0800 Subject: [PATCH 1/3] [Matrix][SYCL] add support for bf16's wi_element --- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 270 ++++++++++++++++++ 1 file changed, 270 insertions(+) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 14466fd5fafb4..d862e0dcdab16 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -451,6 +451,276 @@ class wi_element { } }; +template +class wi_element { + joint_matrix &M; + std::size_t idx; + +public: + wi_element(joint_matrix &Mat, + std::size_t i) + : M(Mat), idx(i) {} + operator uint16_t() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx) != + static_cast(0); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator=(const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element & + operator=(const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + static float make_fp32(uint16_t x) { + unsigned int y = x; + y 
= y << 16; + float *res = reinterpret_cast(&y); + return *res; + } + + static uint16_t make_bf16(float x) { + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (uint16_t)*res; + } + + friend uint16_t + operator+(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_bf16( + make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) + + make_fp32(rhs)); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator+=(const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, + make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx)) + + make_fp32(rhs)), + idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend uint16_t + operator-(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_bf16( + make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) - + make_fp32(rhs)); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator-=(const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, + make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx)) - + make_fp32(rhs)), + idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend uint16_t + operator*(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_bf16( + make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) * + make_fp32(rhs)); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint 
matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator*=(const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, + make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx)) * + make_fp32(rhs)), + idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend uint16_t + operator/(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_bf16( + make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) / + make_fp32(rhs)); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator/=(const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, + make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx)) / + make_fp32(rhs)), + idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator<(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) < + make_fp32(rhs); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator<=(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) <= + make_fp32(rhs); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + 
operator>(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) > + make_fp32(rhs); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator>=(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) >= + make_fp32(rhs); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator==(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) == + make_fp32(rhs); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator!=(const wi_element &lhs, + const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) != + make_fp32(rhs); +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } +}; + template class wi_slice { From 22542aa1b07d78189ca7c65d3278aa8cca73e6df Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 27 Jan 2022 17:41:21 +0800 Subject: [PATCH 2/3] Add some comments --- sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index d862e0dcdab16..941dd7ac243e9 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -451,6 
+451,11 @@ class wi_element { } }; +// uint16_t here represents bf16 type as we have been doing for the +// previous matrix implementations. Since the AMX and DPAS implementations don't +// support uint16_t, this should raise no problem. Our plan is to move towards +// SYCL bfloat16 once it makes itself to the specification (it is experimental +// right now). template class wi_element { joint_matrix &M; @@ -503,6 +508,9 @@ class wi_element { #endif // __SYCL_DEVICE_ONLY__ } + // For now we use the following functions for convertion(bf16=>fp32, + // fp32=>bf16) as a workaround. In the future we will use + // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL static float make_fp32(uint16_t x) { unsigned int y = x; y = y << 16; From a5f902a8b814ef2eb380132381dcbe6b8484823e Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Sat, 29 Jan 2022 01:07:35 +0800 Subject: [PATCH 3/3] Modify some comments --- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 941dd7ac243e9..b3da11e7c439d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -451,11 +451,12 @@ class wi_element { } }; -// uint16_t here represents bf16 type as we have been doing for the -// previous matrix implementations. Since the AMX and DPAS implementations don't -// support uint16_t, this should raise no problem. Our plan is to move towards -// SYCL bfloat16 once it makes itself to the specification (it is experimental -// right now). +// Note that similarly to the other matrix functions, uint16_t is used here to +// represent bf16 type. Since the AMX and DPAS implementations don't support +// uint16_t, this interpretation is possible. This design choice was made before +// the introduction of SYCL experimental bfloat16 type. 
Our plan is to move +// towards using the SYCL bfloat16. But since it is still experimental, we will +// probably keep both the uint16_t interpretation and SYCL bfloat16. template class wi_element { joint_matrix &M; @@ -508,9 +509,10 @@ class wi_element { #endif // __SYCL_DEVICE_ONLY__ } - // For now we use the following functions for convertion(bf16=>fp32, - // fp32=>bf16) as a workaround. In the future we will use - // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL + // We use the following functions here for conversion (bf16=>fp32 and + // fp32=>bf16). This is a workaround until we are able to use + // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are + // supported in the CPU backend. static float make_fp32(uint16_t x) { unsigned int y = x; y = y << 16;