diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp similarity index 99% rename from sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp rename to sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp index 37120646ecad1..ce147d5bc311a 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp @@ -1,4 +1,4 @@ -//===---- matrix-tensorcore.hpp - SYCL tensor cores matrix ----*- C++ -*---===// +//===-------------- matrix-tensorcores-legacy.hpp - -----------*- C++ -*---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp new file mode 100644 index 0000000000000..5e9faba8e94ec --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -0,0 +1,639 @@ + +//===-------- matrix-tensorcores.hpp - matrix ext impl ---*- C++ -*-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===-------------------------------------------------------------------=== // + +#pragma once +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace oneapi { +namespace experimental { +namespace matrix { + +enum class use { a, b, accumulator }; + +enum class layout { row_major, col_major, dynamic }; + +namespace precision { +class tf32 { + tf32() = delete; +}; +} // namespace precision + +template +struct joint_matrix; + +template +class wi_data { + + joint_matrix &jm; + + wi_data(joint_matrix &_jm) : jm(_jm){}; + + template + friend wi_data + get_wi_data(Grp, + joint_matrix &); + +public: + size_t length() { return jm.cuda_impl.wi_marray.size(); }; + + decltype(auto) operator[](size_t i) { return (jm.cuda_impl.wi_marray[i]); }; +}; + +} // namespace matrix +} // namespace experimental + +namespace detail { + +template +struct joint_matrix_cuda; + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE) \ + template \ + struct joint_matrix_cuda< \ + TYPE, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, Layout, \ + typename std::enable_if_t< \ + Layout == \ + sycl::ext::oneapi::experimental::matrix::layout::row_major || \ + Layout == \ + sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \ + marray wi_marray; \ + }; + +// m8n32k16 +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, a, 8, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, b, 16, 32, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 8, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 32, 16) + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 8, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 32, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 8, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 32, 16) +// m32n8k16 +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, a, 32, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, b, 16, 8, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 32, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 8, 16) + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 16, 16) 
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 8, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 32, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 8, 4) +// m16n16k16 +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, a, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, b, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 16, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 16, 16) + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 16, 8) +// m8n8k4 double only +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N, SIZE) \ + template <> \ + struct joint_matrix_cuda< \ + TYPE, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, \ + sycl::ext::oneapi::experimental::matrix::layout::dynamic> { \ + marray wi_marray; \ + }; + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(double, 8, 8, 2) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(PRECISION, USE, M, N, TYPE, \ + SIZE) \ + template \ + struct joint_matrix_cuda< \ + PRECISION, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, \ + Layout, \ + typename std::enable_if_t< \ + Layout == \ + sycl::ext::oneapi::experimental::matrix::layout::row_major || \ + Layout == \ + sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \ + marray wi_marray; \ + }; +// m16n16k8 tf32 only +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION( + sycl::ext::oneapi::experimental::matrix::precision::tf32, a, 16, 8, float, + 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION( + sycl::ext::oneapi::experimental::matrix::precision::tf32, b, 8, 16, float, + 4) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +constexpr int get_layout_id(); + +template <> +constexpr int +get_layout_id() { + return 0; +} + +template <> +constexpr int +get_layout_id() { + return 1; +} + +template +void load_accumulator_layoutT( + joint_matrix_cuda< + S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, + NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same_v) { + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + __imma_m16n16k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); + } + } else if constexpr (std::is_same_v) { + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows 
== 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } + } else if constexpr (std::is_same_v) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same_v) { + __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), src.get(), + stride, get_layout_id()); + } +}; + +template +void load_accumulator_cuda( + joint_matrix_cuda< + S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, + NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout Layout) { + switch (Layout) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src, + stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src, + stride); + break; + default: + assert(false && "Invalid layout specified!"); + } +} + +template < + typename S, typename T, size_t NumRows, size_t NumCols, + sycl::ext::oneapi::experimental::matrix::use Use, + sycl::ext::oneapi::experimental::matrix::layout Layout, + access::address_space Space, access::decorated IsDecorated, + std::enable_if_t< + Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major || + Layout == + sycl::ext::oneapi::experimental::matrix::layout::col_major, + bool> = true> +void load_multiplicand_cuda( + joint_matrix_cuda &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same_v) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { + __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::use::b) { + __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same_v) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == 
sycl::ext::oneapi::experimental::matrix::use::a) { + __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::use::b) { + __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same_v) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { + __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::use::b) { + __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same_v) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { + __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + } else if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::use::b) { + __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + } + + } else if constexpr (std::is_same_v) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 8) { + __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 16) { + __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same_v) { + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { + __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id()); + } else if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::use::b) { + __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, 
get_layout_id()); + } + } +} + +template +void store_layoutT( + joint_matrix_cuda< + T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, + NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, + multi_ptr dst, size_t stride) { + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (std::is_same_v) { + __hmma_m16n16k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same_v) { + __imma_m16n16k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same_v) { + __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 32) { + if constexpr (std::is_same_v) { + __hmma_m8n32k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same_v) { + __imma_m8n32k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same_v) { + __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (NumRows == 32 && NumCols == 8) { + if constexpr (std::is_same_v) { + __hmma_m32n8k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same_v) { + __imma_m32n8k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same_v) { + __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (std::is_same_v) { + __dmma_m8n8k4_st_c_f64(dst.get(), + reinterpret_cast(&src.wi_marray), stride, + get_layout_id()); + } +} + +template +void joint_matrix_store_cuda( + joint_matrix_cuda< + T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, + NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, + multi_ptr dst, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout Layout) { + switch (Layout) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + store_layoutT( + src, dst, stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + store_layoutT( + src, dst, stride); + break; + default: + assert(false && "Invalid layout specified!"); + } +} + +template +constexpr int get_layout_pair_id(); + +template <> +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::row_major, + sycl::ext::oneapi::experimental::matrix::layout::row_major>() { + return 0; +} + +template <> +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::row_major, + sycl::ext::oneapi::experimental::matrix::layout::col_major>() { + return 1; +} + +template <> +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::col_major, + sycl::ext::oneapi::experimental::matrix::layout::row_major>() { + return 2; +} + +template <> +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::col_major, + sycl::ext::oneapi::experimental::matrix::layout::col_major>() { + return 3; +} + +template < + typename Tm, typename Tc, std::size_t M, std::size_t K, std::size_t N, + sycl::ext::oneapi::experimental::matrix::layout LayoutA, + sycl::ext::oneapi::experimental::matrix::layout LayoutB, 
+ std::enable_if_t< + (LayoutA == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + LayoutA == + sycl::ext::oneapi::experimental::matrix::layout::col_major) && + (LayoutB == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + LayoutB == + sycl::ext::oneapi::experimental::matrix::layout::col_major), + bool> = true> +void joint_matrix_mad_cuda( + joint_matrix_cuda< + Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, + joint_matrix_cuda &A, + joint_matrix_cuda &B, + joint_matrix_cuda< + Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { + if constexpr (M == 16 && N == 16 && K == 16) { + if constexpr (std::is_same_v) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same_v) { + __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same_v) { + __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same_v) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same_v) { + __hmma_m16n16k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + + } else if constexpr (std::is_same_v) { + __hmma_m16n16k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same_v) { + __mma_bf16_m16n16k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (M == 8 && N == 32 && K == 16) { + if constexpr (std::is_same_v) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same_v) { + __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same_v) { + __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same_v) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same_v) { + __hmma_m8n32k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same_v) { + __hmma_m8n32k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same_v) { + __mma_bf16_m8n32k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (M == 32 && N == 8 && K == 16) { + if constexpr (std::is_same_v) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same_v) { + __imma_m32n8k16_mma_s8(ptrD, ptrA, 
ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same_v) { + __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same_v) { + __mma_bf16_m32n8k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same_v) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same_v) { + __hmma_m32n8k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same_v) { + __hmma_m32n8k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } + } else if constexpr (M == 16 && N == 16 && K == 8) { + __mma_tf32_m16n16k8_mma_f32(reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same_v) { + __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } +} + +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + +} // namespace detail +} // namespace oneapi +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp new file mode 100644 index 0000000000000..bbdbdfc2f71b5 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -0,0 +1,221 @@ +//===------- matrix-unified.hpp - SYCL matrix extension ----*- C++ -*------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // + +#pragma once +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace oneapi { +namespace experimental { +namespace matrix { + +template +struct joint_matrix { + +#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__) + // TODO: Intel case here: we use the ext_oneapi_cuda case also for the host, + // because the Intel SPIRV functions will not be host compilable. +#else + sycl::ext::oneapi::detail::joint_matrix_cuda + cuda_impl; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__) + + joint_matrix() { +#ifndef __SYCL_DEVICE_ONLY__ + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif + } +}; + +template +inline __SYCL_ALWAYS_INLINE wi_data +get_wi_data(Group sg, joint_matrix &jm) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + return wi_data(jm); +#else + // TODO add Intel impl. 
+#endif // defined(__NVPTX__) +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template +inline __SYCL_ALWAYS_INLINE void +joint_matrix_fill(Group sg, + joint_matrix &res, + const T2 &v) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + res.cuda_impl.wi_marray = v; +#endif // defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = res; + std::ignore = v; + throw runtime_error( + "This version of the matrix extension is only currently supported on " + "Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template < + typename Group, typename S, typename T, size_t NumRows, size_t NumCols, + access::address_space Space, access::decorated IsDecorated, + std::enable_if_t>::value, bool> = + true> +inline __SYCL_ALWAYS_INLINE void joint_matrix_load( + Group sg, + joint_matrix &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout Layout) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride, + Layout); +#endif // defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = res; + std::ignore = src; + std::ignore = stride; + throw runtime_error( + "This version of the matrix extension is only currently supported on " + "Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template < + typename Group, typename S, typename T, use Use, size_t NumRows, + size_t NumCols, matrix::layout Layout, access::address_space Space, + access::decorated IsDecorated, + std::enable_if_t>::value || + (std::is_same::value && + std::is_same, float>::value), + bool> = true> +inline __SYCL_ALWAYS_INLINE void +joint_matrix_load(Group sg, + joint_matrix &res, + multi_ptr src, size_t stride) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + sycl::ext::oneapi::detail::load_multiplicand_cuda( + res.cuda_impl, src, stride); +#endif // defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = res; + std::ignore = src; + std::ignore = stride; + throw runtime_error( + "This version of the matrix extension is only currently supported on " + "Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_store( + Group sg, + joint_matrix &src, + multi_ptr dst, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout Layout) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + sycl::ext::oneapi::detail::joint_matrix_store_cuda(src.cuda_impl, dst, + stride, Layout); +#endif // defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = src; + std::ignore = dst; + std::ignore = stride; + throw runtime_error( + "This version of the matrix extension is only currently supported on " + "Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template +inline __SYCL_ALWAYS_INLINE + joint_matrix + joint_matrix_mad( + Group sg, joint_matrix &A, + joint_matrix &B, + joint_matrix + &C) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + if constexpr (std::is_same::value) { + joint_matrix + D; + sycl::ext::oneapi::detail::joint_matrix_mad_cuda( + D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); + return D; + } else { + assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " + "requires that joint_matrix data types Ta and Tb match"); + } +#endif // 
defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = A; + std::ignore = B; + std::ignore = C; + throw runtime_error( + "This version of the matrix extension is only currently supported on " + "Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +// This function rounds the bottom 13 bits up or down, and then zeros out the +// bottom bits +inline __SYCL_ALWAYS_INLINE float round_to_tf32(float &a) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + int32_t tmp_int = __nvvm_f2tf32_rna(a); + return __nvvm_bitcast_i2f(tmp_int); +#else + uint32_t tmp_uint = reinterpret_cast(a); + tmp_uint += 0x1000u; + tmp_uint &= 0xFFFFE000u; + float ret = reinterpret_cast(tmp_uint); + return ret; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +} // namespace matrix +} // namespace experimental +} // namespace oneapi +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp index 00ae3792626c6..f3e4ab90dc758 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp @@ -27,5 +27,8 @@ #include #endif // SYCL_EXT_ONEAPI_MATRIX_VERSION #if (SYCL_EXT_ONEAPI_MATRIX_VERSION == 3) -#include +#include +#endif // SYCL_EXT_ONEAPI_MATRIX_VERSION +#if (SYCL_EXT_ONEAPI_MATRIX_VERSION == 4) +#include #endif // SYCL_EXT_ONEAPI_MATRIX_VERSION diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp index dc28943d9b102..1310cb1b167f9 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp @@ -1,7 +1,7 @@ // REQUIRES: cuda -// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s -// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE +// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s +// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE #include @@ -39,21 +39,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, 
accC.get_pointer(), stride, + layout::row_major); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -65,7 +60,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::row_major); }); cgh.parallel_for( @@ -73,21 +69,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::col_major); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -99,7 +90,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::col_major); }); cgh.parallel_for( @@ -107,20 +99,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, 
accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::row_major); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -132,7 +120,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::row_major); }); cgh.parallel_for( @@ -140,20 +129,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::col_major); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -165,7 +150,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::col_major); }); cgh.parallel_for( @@ -173,20 +159,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, 
float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::row_major); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -198,7 +180,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::row_major); }); cgh.parallel_for( @@ -206,20 +189,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::col_major); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -231,7 +210,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::col_major); }); }); diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp index fd71ca826f901..65a0a9e944254 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp @@ -1,7 +1,7 @@ // REQUIRES: cuda -// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S 
-Xclang -emit-llvm %s -o -| FileCheck %s -// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE +// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s +// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE #include @@ -49,19 +49,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %_arg_accC, i32 8) //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1(ptr addrspace(1) %_arg_accC, i32 8) - joint_matrix_load(sg, sub_c, accC.get_pointer(), N); + joint_matrix_load(sg, sub_c, accC.get_pointer(), N, + layout::row_major); //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %_arg_accA, i32 4) //CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1(ptr addrspace(1) %_arg_accA, i32 4) joint_matrix_load(sg, sub_a, accA.get_pointer(), K); @@ -73,7 +70,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %_arg_accD, double %6, double %7, i32 8) //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %_arg_accD, double {{.*}}, double {{.*}}, i32 8) - joint_matrix_store(sg, sub_c, accD.get_pointer(), N); + joint_matrix_store(sg, sub_c, accD.get_pointer(), N, + layout::row_major); }); cgh.parallel_for( @@ -81,19 +79,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %_arg_accC, i32 8) //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1(ptr addrspace(1) %_arg_accC, i32 8) - joint_matrix_load(sg, sub_c, accC.get_pointer(), M); + joint_matrix_load(sg, sub_c, accC.get_pointer(), M, + layout::col_major); //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %_arg_accA, i32 8) //CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1(ptr addrspace(1) %_arg_accA, i32 8) joint_matrix_load(sg, sub_a, accA.get_pointer(), M); @@ -105,7 +100,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %_arg_accD, double %6, double %7, i32 8) //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) 
%_arg_accD, double {{.*}}, double {{.*}}, i32 8) - joint_matrix_store(sg, sub_c, accD.get_pointer(), M); + joint_matrix_store(sg, sub_c, accD.get_pointer(), M, + layout::col_major); }); }); diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-half-float-test.cpp index ed3ba07dd48fa..c3b63668a63bc 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-half-float-test.cpp @@ -1,7 +1,7 @@ // REQUIRES: cuda -// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s -// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE +// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s +// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE #include @@ -38,19 +38,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -62,7 +59,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::row_major); }); cgh.parallel_for( @@ -70,19 +68,16 @@ int main() { [=](nd_item<2> item) 
[[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -94,7 +89,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::col_major); }); cgh.parallel_for( @@ -102,19 +98,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -126,7 +119,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) 
%_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::row_major); }); cgh.parallel_for( @@ -134,19 +128,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -158,7 +149,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::col_major); }); cgh.parallel_for( @@ -166,19 +158,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -190,7 +179,8 @@ int main() { sub_c = 
joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::row_major); }); cgh.parallel_for( @@ -198,19 +188,16 @@ int main() { [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); - joint_matrix - sub_c; - - joint_matrix - sub_a; - - joint_matrix - sub_b; + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, + layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16) joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); @@ -222,7 +209,8 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, + layout::col_major); }); }); diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp index 602abad05830f..d38fa9c2aee5d 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp @@ -1,7 +1,7 @@ // REQUIRES: cuda -// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s -// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE +// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend 
diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp
index 602abad05830f..d38fa9c2aee5d 100644
--- a/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp
+++ b/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp
@@ -1,7 +1,7 @@
 // REQUIRES: cuda
-// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
-// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
+// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s
+// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
 #include
@@ -38,19 +38,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -62,7 +59,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -70,19 +68,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
-          joint_matrix
-              sub_b;
-
-          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
-          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i{{.*}}.i.i{{.*}}.i, i32 16)
+          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f16.p0(ptr %call.ascast.i{{.*}}.i.i{{.*}}.i, i32 16)
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -92,9 +87,10 @@ int main() {
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
-          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
-          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i{{.*}}.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
+          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %call.ascast.i{{.*}}.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });

      cgh.parallel_for(
@@ -102,19 +98,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -126,7 +119,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -134,19 +128,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
-
-          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
-          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i{{.*}}.i.i{{.*}}.i, i32 16)
+          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f16.p0(ptr %call.ascast.i{{.*}}.i.i{{.*}}.i, i32 16)
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -156,9 +147,10 @@ int main() {
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
-          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
-          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i{{.*}}.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
+          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %call.ascast.i{{.*}}.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });

      cgh.parallel_for(
@@ -166,19 +158,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -190,7 +179,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -198,19 +188,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
-          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
-          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i{{.*}}.i.i{{.*}}.i, i32 16)
+          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f16.p0(ptr %call.ascast.i{{.*}}.i.i{{.*}}.i, i32 16)
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -220,9 +207,10 @@ int main() {
          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5)
          // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
-          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
-          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %call.ascast.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i{{.*}}.i.i{{.*}}.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16)
+          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %call.ascast.i{{.*}}.i.i{{.*}}.i, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16)
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });
  });
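Note: the declaration side of the change is visible throughout the file above: a and b fragments now carry use::a/use::b and their layout as template parameters, and every fragment is value-initialized with {}. A sketch of the declarations for the m32n8k16 half/half kernel, with assumed template arguments (the originals were stripped from this extract):

    // Assumed v4 spellings; a/b fix their layout in the type, the
    // accumulator does not.
    joint_matrix<half, use::accumulator, 32, 8> sub_c{};
    joint_matrix<half, use::a, 32, 16, layout::col_major> sub_a{};
    joint_matrix<half, use::b, 16, 8, layout::col_major> sub_b{};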
diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp
index 08045cf688485..ed09b68cf6797 100644
--- a/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp
+++ b/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp
@@ -1,7 +1,7 @@
 // REQUIRES: cuda
-// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
-// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
+// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s
+// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
 #include
@@ -38,19 +38,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.s8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.s8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -62,7 +59,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -70,19 +68,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.s8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.s8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -94,7 +89,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });

      cgh.parallel_for(
@@ -102,19 +98,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.s8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.s8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -126,7 +119,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -134,19 +128,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.s8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.s8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -158,7 +149,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });

      cgh.parallel_for(
@@ -166,19 +158,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.s8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.s8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -190,7 +179,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -198,19 +188,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.s8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.s8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -222,7 +209,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });
  });
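Note: the integer tests follow the identical pattern; only the element types differ. The CHECK lines above pin the s8 load and s32 accumulate intrinsics for each shape. A sketch of the m16n16k16 signed case (template arguments assumed, as elsewhere):

    joint_matrix<int8_t, use::a, 16, 16, layout::row_major> sub_a{};
    joint_matrix<int8_t, use::b, 16, 16, layout::row_major> sub_b{};
    joint_matrix<int32_t, use::accumulator, 16, 16> sub_c{};
    // 8-bit inputs accumulate into 32-bit integers; the tests verify the
    // corresponding s8/s32 (and, below, u8/s32) WMMA intrinsics.
    sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);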
diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp
index 46038cc30eebf..fc5c55045b9f2 100644
--- a/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp
+++ b/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp
@@ -1,7 +1,7 @@
 // REQUIRES: cuda
-// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
-// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
+// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s
+// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
 // IMPORTANT: before updating sm version support beyond sm_90 read the following
 // NOTE!
@@ -59,17 +59,13 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
-
-          joint_matrix
-              sub_c;
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
+          joint_matrix sub_c{};
          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8)
          //CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0(ptr %call.ascast.i.i{{.*}}.i, i32 8)
@@ -79,21 +75,25 @@ int main() {
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
          //CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), N,
+                            layout::row_major);
          // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
          // Round a, b to tf32
          for (auto i = 0; i < 4; ++i)
-            sub_a.wi_marray[i] = round_to_tf32(sub_a.wi_marray[i]);
+            get_wi_data(sg, sub_a)[i] =
+                round_to_tf32(get_wi_data(sg, sub_a)[i]);
          for (auto i = 0; i < 4; ++i)
-            sub_b.wi_marray[i] = round_to_tf32(sub_b.wi_marray[i]);
+            get_wi_data(sg, sub_b)[i] =
+                round_to_tf32(get_wi_data(sg, sub_b)[i]);
          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}})
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}}
          //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}}
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), N,
+                             layout::row_major);
        });
  });
@@ -108,17 +108,13 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
-
-          joint_matrix
-              sub_c;
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
+          joint_matrix sub_c{};
          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8)
          //CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0(ptr %call.ascast.i.i{{.*}}.i, i32 8)
@@ -128,21 +124,25 @@ int main() {
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, i32 {{.*}})
          //CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1(ptr addrspace(1) {{.*}}, i32 {{.*}})
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), N,
+                            layout::col_major);
          // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
          // Round a, b to tf32
          for (auto i = 0; i < 4; ++i)
-            sub_a.wi_marray[i] = round_to_tf32(sub_a.wi_marray[i]);
+            get_wi_data(sg, sub_a)[i] =
+                round_to_tf32(get_wi_data(sg, sub_a)[i]);
          for (auto i = 0; i < 4; ++i)
-            sub_b.wi_marray[i] = round_to_tf32(sub_b.wi_marray[i]);
+            get_wi_data(sg, sub_b)[i] =
+                round_to_tf32(get_wi_data(sg, sub_b)[i]);
          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}})
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16)
          //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), N,
+                             layout::col_major);
        });
  });
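Note: besides the layout arguments, the tf32 file swaps direct wi_marray access for the get_wi_data accessor; wi_data exposes length() and operator[]. The rounding loop could equivalently hoist the accessor out of the loop, a readability variant rather than what the patch writes:

    // Round each work-item element of sub_a to tf32 precision.
    auto wi_a = get_wi_data(sg, sub_a);
    for (size_t i = 0; i < wi_a.length(); ++i)
      wi_a[i] = round_to_tf32(wi_a[i]);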
diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp
index 4c234d5ffba37..475fd44eaed72 100644
--- a/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp
+++ b/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp
@@ -1,7 +1,7 @@
 // REQUIRES: cuda
-// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
-// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
+// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s
+// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
 #include
@@ -38,19 +38,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.u8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.u8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -62,7 +59,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -70,19 +68,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.u8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.u8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -94,7 +89,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });

      cgh.parallel_for(
@@ -102,19 +98,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.u8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.u8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -126,7 +119,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -134,19 +128,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.u8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.u8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -158,7 +149,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });

      cgh.parallel_for(
@@ -166,19 +158,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
          // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.u8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.u8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -190,7 +179,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
        });

      cgh.parallel_for(
@@ -198,19 +188,16 @@ int main() {
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();
-          joint_matrix
-              sub_c;
-
-          joint_matrix
-              sub_a;
-
-          joint_matrix
-              sub_b;
+          joint_matrix sub_c{};
+          joint_matrix
+              sub_a{};
+          joint_matrix
+              sub_b{};
          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
          // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1(ptr addrspace(1) %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
          // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.u8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
          // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.u8.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
@@ -222,7 +209,8 @@ int main() {
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
          // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %_arg_accD, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
        });
  });
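Note: the uint8 file completes the migration applied uniformly above: VERSION=3 becomes VERSION=4 in the RUN lines, fragments are value-initialized with use/layout template parameters, and accumulator loads/stores pass the layout explicitly. A compact sketch of one migrated kernel body, reconstructed from the hunks rather than copied from the patch (identifiers and template arguments are illustrative):

    [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
      sycl::sub_group sg = item.get_sub_group();
      joint_matrix<uint8_t, use::a, 16, 16, layout::row_major> sub_a{};
      joint_matrix<uint8_t, use::b, 16, 16, layout::row_major> sub_b{};
      joint_matrix<int32_t, use::accumulator, 16, 16> sub_c{};
      // Accumulator layout is supplied at the call site under v4.
      joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
                        layout::row_major);
      joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
      joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
      sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
      joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
                         layout::row_major);
    });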