diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index fc240e4fa1818..b90b91abe65e1 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -218,10 +218,11 @@ struct joint_matrix_load_impl< void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride) { - if constexpr (std::is_same::value || + if constexpr (std::is_same, uint16_t>::value || std::is_same< - T, sycl::ext::oneapi::experimental::bfloat16>::value) { - auto tileptr = reinterpret_cast(src.get()); + std::remove_const_t, + sycl::ext::oneapi::experimental::bfloat16>::value) { + auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == @@ -246,8 +247,8 @@ struct joint_matrix_load_impl< __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); + } else if constexpr (std::is_same, uint8_t>::value) { + auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == @@ -272,8 +273,8 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); + } else if constexpr (std::is_same, int8_t>::value) { + auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == @@ -298,8 +299,8 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); + } else if constexpr (std::is_same, half>::value) { + auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == @@ -331,7 +332,7 @@ struct joint_matrix_load_impl< get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same, int32_t>::value) { auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { __imma_m16n16k16_ld_c(destptr, src.get(), stride, @@ -343,7 +344,7 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same, float>::value) { if constexpr (std::is_same::value) { auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { @@ -359,7 +360,7 @@ struct joint_matrix_load_impl< } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); + auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 8) { __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, @@ -369,7 +370,7 @@ struct joint_matrix_load_impl< get_layout_id()); } } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same, double>::value) { auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { @@ -559,9 +560,9 @@ struct joint_matrix_mad_impl< D; if constexpr (M == 16 && N == 16 && K == 16) { if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); if constexpr (std::is_same::value) { __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, @@ -571,17 +572,17 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } else if constexpr (std::is_same::value || @@ -589,16 +590,16 @@ struct joint_matrix_mad_impl< bfloat16>::value) { __mma_bf16_m16n16k16_mma_f32( reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } else if constexpr (M == 8 && N == 32 && K == 16) { if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); if constexpr (std::is_same::value) { __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, @@ -608,17 +609,17 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { __hmma_m8n32k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { __hmma_m8n32k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } else if constexpr (std::is_same::value || @@ -626,16 +627,16 @@ struct joint_matrix_mad_impl< bfloat16>::value) { __mma_bf16_m8n32k16_mma_f32( reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } else if constexpr (M == 32 && N == 8 && K == 16) { if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); if constexpr (std::is_same::value) { __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, @@ -649,22 +650,22 @@ struct joint_matrix_mad_impl< bfloat16>::value) { __mma_bf16_m32n8k16_mma_f32( reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } @@ -676,9 +677,9 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } return D; @@ -691,13 +692,14 @@ struct joint_matrix_mad_impl< namespace experimental { namespace matrix { -template ::value || - (std::is_same::value && - std::is_same::value), - bool> = true> +template < + typename Group, typename S, typename T, matrix_use Use, size_t NumRows, + size_t NumCols, matrix_layout Layout, access::address_space Space, + std::enable_if_t>::value || + (std::is_same::value && + + std::is_same, float>::value), + bool> = true> void joint_matrix_load( Group sg, joint_matrix &res, multi_ptr src, size_t stride) {