From b68aeadc3c0792823772e080bae0b7ec6c914368 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 11:46:01 +0800 Subject: [PATCH 01/23] =?UTF-8?q?[Matrix]=20syntax=20changes=20as=20prepra?= =?UTF-8?q?ration=20before=20moving=20joint=20matrix=20from=20experimental?= =?UTF-8?q?=20namespace=20As=20part=20of=20the=20effort=20to=20move=20join?= =?UTF-8?q?t=20matrix=20from=20experimental=20namespace=20to=20supported.?= =?UTF-8?q?=20A=20review=20of=20the=20API=20is=20being=20done=20as=20part?= =?UTF-8?q?=20of=20https://github.com/intel/llvm/pull/7964.=20This=20resul?= =?UTF-8?q?ts=20in=20the=20following=20changes=20in=20the=20syntax:=201-?= =?UTF-8?q?=20Add=20Td=20to=20joint=5Fmatrix=5Fmad=20as=20Tc=20can=20be=20?= =?UTF-8?q?different=20from=20Td=20on=20the=20GPU,=20Now,=20we=20make=20D?= =?UTF-8?q?=20as=20an=20input=20argument=20to=20mad.=202-=20=20Change=20?= =?UTF-8?q?=E2=80=9Cpacked=E2=80=9D=20to=20ext=5Fintel=5Fpacked:=203-=20?= =?UTF-8?q?=20Move=20EWOps=20(get=5Fwi=5Fdata,=20wi=5Felement,=20get=5Fcoo?= =?UTF-8?q?rd)=20to=20detail=20namespace)=204-=20add=20const=20to=20joint?= =?UTF-8?q?=5Fmatrix=20in=20store=20and=20mad=205=20-=20add=20joint=5Fmatr?= =?UTF-8?q?ix=5Fcopy/assignment=20function=206-=20add=20apply=20with=20coo?= =?UTF-8?q?rdination=20(change=20existing=20tests)=207-=20change=20get=5Fc?= =?UTF-8?q?oord=20vector=20type=20from=20int32=5Ft=20to=20size=5Ft=208-=20?= =?UTF-8?q?delete=20explicitly=20both=20=3D=20and=20copy=20ctor.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sycl/ext/oneapi/matrix/matrix-intel.hpp | 59 +++++++--- .../oneapi/matrix/matrix-unified-utils.hpp | 7 +- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 105 ++++++++++-------- .../Matrix/Legacy/element_wise_ops_impl.hpp | 2 +- .../Legacy/elemwise_irreg_size_ops_bf16.cpp | 2 +- .../Matrix/Legacy/joint_matrix_bf16_impl.hpp | 2 +- .../joint_matrix_bfloat16_32x64_impl.hpp | 2 +- ...trix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- .../Legacy/joint_matrix_bfloat16_impl.hpp | 2 +- ...trix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 2 +- ...t_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- .../Legacy/joint_matrix_int8_vnni_impl.hpp | 2 +- .../Legacy/joint_matrix_query_default.cpp | 2 +- .../Legacy/joint_matrix_ss_int8_impl.hpp | 2 +- .../Legacy/joint_matrix_su_int8_impl.hpp | 2 +- .../Legacy/joint_matrix_us_int8_impl.hpp | 2 +- .../Legacy/joint_matrix_uu_int8_impl.hpp | 2 +- .../test-e2e/Matrix/element_wise_abc_impl.hpp | 14 +-- .../Matrix/element_wise_all_ops_half_impl.hpp | 10 +- .../Matrix/element_wise_all_ops_impl.hpp | 4 +- .../Matrix/element_wise_all_ops_int8_impl.hpp | 10 +- .../element_wise_all_ops_int8_packed_impl.hpp | 40 +++---- .../Matrix/element_wise_all_ops_tf32_impl.hpp | 10 +- .../Matrix/element_wise_all_sizes_impl.hpp | 2 +- .../element_wise_irreg_sum_rows_impl.hpp | 8 +- .../test-e2e/Matrix/element_wise_ops_impl.hpp | 10 +- .../Matrix/elemwise_irreg_size_ops_bf16.cpp | 10 +- .../Matrix/joint_matrix_all_sizes_impl.hpp | 7 +- .../Matrix/joint_matrix_apply_cuda.hpp | 65 ++++++----- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 53 ++++----- .../joint_matrix_bfloat16_32x64_impl.hpp | 7 +- .../joint_matrix_bfloat16_array_impl.hpp | 7 +- ...trix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_bfloat16_impl.hpp | 7 +- ...trix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../joint_matrix_colA_rowB_colC_impl.hpp | 2 +- .../Matrix/joint_matrix_gemm_cuda.hpp | 2 +- .../Matrix/joint_matrix_half_impl.hpp | 7 +- ...t_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_int8_vnni_impl.hpp | 2 +- .../Matrix/joint_matrix_out_bounds_impl.hpp | 2 +- .../Matrix/joint_matrix_query_default.cpp | 5 +- .../Matrix/joint_matrix_ss_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_su_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_tf32_impl.hpp | 8 +- .../Matrix/joint_matrix_transposeC_impl.hpp | 7 +- .../Matrix/joint_matrix_us_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_uu_int8_impl.hpp | 7 +- .../matrix/matrix-nvptx-bfloat16-test.cpp | 12 +- .../cuda/matrix/matrix-nvptx-double-test.cpp | 4 +- .../matrix/matrix-nvptx-half-float-test.cpp | 12 +- .../matrix/matrix-nvptx-half-half-test.cpp | 12 +- .../cuda/matrix/matrix-nvptx-int8-test.cpp | 12 +- .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 +- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 12 +- .../matrix/matrix_load_store_as.cpp | 7 +- .../matrix/matrix_load_store_as_legacy.cpp | 2 +- .../matrix/legacy/matrix-bf16-test-SG-16.cpp | 2 +- sycl/test/matrix/legacy/matrix-bf16-test.cpp | 2 +- .../matrix/legacy/matrix-bfloat16-test.cpp | 2 +- .../matrix/legacy/matrix-elemwise-ops.cpp | 2 +- .../matrix/legacy/matrix-int8-test-SG-16.cpp | 2 +- sycl/test/matrix/legacy/matrix-int8-test.cpp | 2 +- .../matrix-bfloat16-test-coord-basicB.cpp | 8 +- sycl/test/matrix/matrix-bfloat16-test.cpp | 5 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 8 +- sycl/test/matrix/matrix-int8-test.cpp | 5 +- sycl/test/matrix/matrix-tf32-test.cpp | 5 +- sycl/test/matrix/query-use.cpp | 6 +- 70 files changed, 361 insertions(+), 299 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index 9de4f4ec1851f..eac624a9c3360 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -29,10 +29,6 @@ namespace sycl { inline namespace _V1 { namespace ext { -namespace intel::experimental::matrix::layout { -constexpr sycl::ext::oneapi::experimental::matrix::layout packed = - static_cast(2); -} namespace oneapi { namespace experimental { namespace matrix { @@ -48,8 +44,7 @@ template struct spv_matrix_layout_traits { SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor) SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor) -SPV_MATRIX_LAYOUT_TRAITS(sycl::ext::intel::experimental::matrix::layout::packed, - __spv::MatrixLayout::Packed) +SPV_MATRIX_LAYOUT_TRAITS(layout::ext_intel_packed, __spv::MatrixLayout::Packed) SPV_MATRIX_LAYOUT_TRAITS(layout::dynamic, __spv::MatrixLayout::Dynamic) template struct spv_matrix_use_traits { @@ -94,10 +89,6 @@ struct jm_type_interpretation_helper_trait< using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32; using storage_element_type = float; }; -} // namespace detail -} // namespace oneapi - -namespace intel::experimental::matrix { using namespace sycl::ext::oneapi::experimental::matrix; // Begin wi_element definition @@ -121,12 +112,12 @@ class wi_element { std::size_t i) : M(Mat), idx(i) {} - inline __SYCL_ALWAYS_INLINE std::tuple get_coord() { + inline __SYCL_ALWAYS_INLINE std::tuple get_coord() { #if defined(__SYCL_DEVICE_ONLY__) __ocl_vec_t coord = __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx); - const uint32_t row = coord[0]; - const uint32_t col = coord[1]; + const size_t row = coord[0]; + const size_t col = coord[1]; return std::make_tuple(row, col); #else throw runtime_error("joint matrix is not supported on host device.", @@ -479,7 +470,10 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< } // End wi_data definition +} // namespace detail +} // namespace oneapi +namespace intel::experimental::matrix { template < typename Group, typename T, typename Tp, sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows, @@ -490,7 +484,7 @@ template < bool> = true> inline __SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, - sycl::ext::oneapi::experimental::matrix::joint_matrix< + const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout> &src, multi_ptr dst, size_t stride) { #if defined(__SYCL_DEVICE_ONLY__) @@ -528,6 +522,43 @@ joint_matrix_store(Group sg, PI_ERROR_INVALID_DEVICE); #endif // defined(__SYCL_DEVICE_ONLY__) } + +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_apply( + Group sg, + sycl::ext::oneapi::experimental::matrix::joint_matrix &jm, + F &&lambda) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) { + lambda(jm.cuda_impl.wi_marray[i]); + } +#else // NVPTX + using storage_element_type = + typename oneapi::detail::jm_type_interpretation_helper_trait< + T>::storage_element_type; + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm); + for (int i = 0; i < wi_data_c.length(); i++) { + storage_element_type element = wi_data_c[i]; + auto [row, col] = wi_data_c[i].get_coord(); + lambda(element, row, col); + wi_data_c[i] = element; + } +#endif +#else + std::ignore = sg; + std::ignore = jm; + std::ignore = lambda; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif + return; +} + } // namespace intel::experimental::matrix } // namespace ext diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index f51e146fd9a0c..8a9dbc12df2ec 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -16,7 +16,12 @@ namespace matrix { enum class use { a, b, accumulator }; -enum class layout { row_major = 0, col_major = 1, dynamic = 3 }; +enum class layout { + row_major = 0, + col_major = 1, + ext_intel_packed = 2, + dynamic = 3 +}; namespace precision { class tf32 { diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 7b101b18cea90..5fc6290c1f71d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -61,19 +61,8 @@ struct joint_matrix { } #ifdef __SYCL_DEVICE_ONLY__ #if defined(__SPIR__) - // Generate a non-trivial assignment operator and copy c'tor that prevents - // memcpy from being generated. - // TODO: to remove, when either IGC can handle alloca JointMatrix or - // combination of InstCombine + SROA + mem2reg can remove it - joint_matrix(const joint_matrix &other) { - spvm = other.spvm; - return *this; - } - - joint_matrix &operator=(const joint_matrix &rhs) { - spvm = rhs.spvm; - return *this; - } + joint_matrix(const joint_matrix &other) = delete; + joint_matrix &operator=(const joint_matrix &rhs) = delete; #endif // defined(__SPIR__) #endif }; @@ -99,7 +88,7 @@ class wi_data { return jm.cuda_impl.wi_marray.size(); #else throw runtime_error("get_wi_data is available using: " - "ext::intel::experimental::matrix::get_wi_data.", + "ext::oneapi::detail::get_wi_data.", PI_ERROR_INVALID_DEVICE); #endif }; @@ -109,7 +98,7 @@ class wi_data { return (jm.cuda_impl.wi_marray[i]); #else throw runtime_error("get_wi_data is available using: " - "ext::intel::experimental::matrix::get_wi_data.", + "ext::oneapi::detail::get_wi_data.", PI_ERROR_INVALID_DEVICE); #endif }; @@ -138,9 +127,9 @@ template &jm, using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< T>::storage_element_type; - auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm); + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm); for (int i = 0; i < wi_data_c.length(); i++) { storage_element_type element = wi_data_c[i]; lambda(element); @@ -262,7 +251,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::intel::experimental::matrix::layout::packed: + case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed: res.spvm = __spirv_JointMatrixLoadINTEL< DecorT, S, NumRows, NumCols, spv_matrix_use_traits::value, @@ -327,8 +316,9 @@ template inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Group sg, - joint_matrix &src, + const joint_matrix + &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { #if defined(__SYCL_DEVICE_ONLY__) @@ -361,7 +351,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::intel::experimental::matrix::layout::packed: + case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed: __spirv_JointMatrixStoreINTEL< DecorT, T, NumRows, NumCols, spv_matrix_use_traits::value, @@ -382,53 +372,78 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( #endif // defined(__SYCL_DEVICE_ONLY__) } -template -inline __SYCL_ALWAYS_INLINE - joint_matrix - joint_matrix_mad( - Group sg, joint_matrix &A, - joint_matrix &B, - joint_matrix - &C) { +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( + Group sg, const joint_matrix &A, + const joint_matrix &B, + const joint_matrix + &C, + joint_matrix &D) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; if constexpr (std::is_same::value) { - joint_matrix - D; sycl::ext::oneapi::detail::joint_matrix_mad_cuda( D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); - return D; } else { assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " "requires that joint_matrix data types Ta and Tb match"); } #else - joint_matrix res; if constexpr (std::is_same::value && std::is_same::value && std::is_same::value) - res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_unsigned::value && std::is_unsigned::value) - res.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_signed::value && std::is_unsigned::value) - res.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_unsigned::value && std::is_signed::value) - res.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm); else - res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); - return res; + D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); #endif // defined(__NVPTX__) #else std::ignore = sg; std::ignore = A; std::ignore = B; std::ignore = C; + std::ignore = D; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template +void joint_matrix_copy(Group sg, + joint_matrix &src, + joint_matrix &dst) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) { + dest.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; + } +#else + using storage_element_type = + typename oneapi::detail::jm_type_interpretation_helper_trait< + T2>::storage_element_type; + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src); + auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst); + for (int i = 0; i < wi_data_c.length(); i++) { + wi_data_dst[i] = static_cast(wi_data_c[i]); + } +#endif // defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = dst; + std::ignore = src; throw runtime_error("joint matrix is not supported on host device.", PI_ERROR_INVALID_DEVICE); #endif // defined(__SYCL_DEVICE_ONLY__) diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 8d15b78fd3198..7179b82855f50 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index 0f57377c571ac..d041cd6050d73 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -100,7 +100,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 9868aef0d92e2..3ba9ae346e070 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index ac4a0bc405816..8835bce054171 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 91845ac61a180..48c4a894ab385 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -62,7 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index 2598905f9f6fe..db16f05673321 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index c293d8ff22944..5676cd849e1b5 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -62,7 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index c663cc282c758..658985c10ab5c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -73,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 21347c80c083b..85bffe8957e77 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index 1948071dbf405..2cb8196a49306 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index 8aaf737a274a8..6b7878b5cce7d 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -97,7 +97,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index a2436bc56e792..ef594dc6bc3a8 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -70,7 +70,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index f0a9a7155fb0d..0eb5bc2bc8d58 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index 68cf40bb481b9..3d0067cbb7b36 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -78,7 +78,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 14190434dd2b1..be4eee8452c3c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 8b7ee3af2b9c5..1da92f4fd4dbc 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -55,8 +55,9 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, T2, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -65,8 +66,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM) * K, K); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] += 1; } @@ -76,8 +76,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] += 1; } @@ -87,8 +86,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 1; } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index 540d75c245815..063acee4c2ffe 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -42,7 +42,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + static_cast(2); } @@ -77,7 +77,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - static_cast(2); } @@ -112,7 +112,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); } @@ -147,7 +147,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); } @@ -182,7 +182,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > static_cast(2.0) || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 8e15488e151a0..25d77269b03a4 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -64,7 +64,7 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { sub_mat; joint_matrix_fill(sg, sub_mat, l); auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } @@ -105,7 +105,7 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { sub_mat; joint_matrix_fill(sg, sub_mat, l); auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index 803ebe0addb3a..5d1668b753baa 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -41,7 +41,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -76,7 +76,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - 2; } @@ -111,7 +111,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * 3; } @@ -146,7 +146,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / 2; } @@ -181,7 +181,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index ce89a04b4168c..a04d464605fef 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -36,14 +36,14 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] + 2; } @@ -73,14 +73,14 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] - 2; } @@ -110,14 +110,14 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] * 3; } @@ -147,14 +147,14 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 4); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] / 2; } @@ -184,14 +184,14 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { if (wi_slice_b[i]) { if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp index 27eacf89c748a..96fcfa975b408 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp @@ -42,7 +42,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -78,7 +78,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); } @@ -112,7 +112,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); } @@ -147,7 +147,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(4.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); } @@ -181,7 +181,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index c49f9b57e2f32..b18ca01193974 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -62,7 +62,7 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { joint_matrix_fill(sg, sub_a, val1); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + val2; } diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index cfce95cba269f..6e3a86d7cb77b 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -44,8 +44,9 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, T, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_load( @@ -57,8 +58,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto data = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 1206b556339a9..edf9f162d5c3a 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -51,8 +51,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -72,10 +73,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] *= 2; } diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 8e2865de207b4..08724fbb48e9d 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -80,8 +80,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load( @@ -101,10 +102,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 5.0; } diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index 469837cd26d41..f4572990ded76 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -56,8 +56,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, T2, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -78,7 +79,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, (k * TK / vnniFactor) * (N * vnniFactor) + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index 4303239cefe32..b58e0a1b7c467 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -51,34 +51,34 @@ void matrix_verify_lambda(queue q, q.submit([&](handler &cgh) { accessor accC(bufC, cgh); - cgh.parallel_for>(r, [ - accC, lambda - ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - auto sg = spmd_item.get_sub_group(); - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - - joint_matrix_fill(sg, sub_a, 3); - joint_matrix_fill(sg, sub_b, 1); - joint_matrix_fill(sg, sub_c, -80); - - joint_matrix_apply(sg, sub_a, lambda); - - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); - }); // parallel for + cgh.parallel_for>( + r, [accC, lambda]( + nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; + + joint_matrix_fill(sg, sub_a, 3); + joint_matrix_fill(sg, sub_b, 1); + joint_matrix_fill(sg, sub_c, -80); + + joint_matrix_apply(sg, sub_a, lambda); + + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); + }); // parallel for }); } assert_ref(C.get_data(), ref); @@ -111,8 +111,8 @@ void matrix_verify_op(queue q, big_matrix &C, cgh); cgh.parallel_for>( - r, [ accC, - Op ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { + r, [accC, + Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); @@ -154,7 +154,7 @@ void matrix_verify_op(queue q, big_matrix &C, } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); joint_matrix_store( sg, sub_c, @@ -162,8 +162,7 @@ void matrix_verify_op(queue q, big_matrix &C, (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, (N * nWGperDim), layout::row_major); }); // parallel for - }) - .wait(); + }).wait(); } assert_ops_ref(C.get_data(), ref); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index f2d359cfe130b..74af36f0238ef 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -147,36 +147,37 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { #endif ; - joint_matrix + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> tB[NCACHE1 / tN][KCACHE2 / KCACHE1] #ifdef INIT_LIST = { - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), } #endif ; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index cc0196660744a..c27df0c01caa4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -46,8 +46,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -68,7 +69,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index d6390d8061dcc..9cdf9b2435f82 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -59,8 +59,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a[JM_ARRAY_SZ]; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c[JM_ARRAY_SZ]; @@ -81,7 +82,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM * JM_ARRAY_SZ + TM * i) * K + k * TK, K); - sub_c[i] = joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i]); + joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i], sub_c[i]); } } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 4847e093127a8..2c627d8b88e31 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index 76ac69da27677..a35dd11ee8cfa 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -45,8 +45,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -66,7 +67,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 4d61a733e5927..08341b3835d57 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp index f75da1824d94b..a90a46e258452 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp @@ -48,7 +48,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K); joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 5e451d45d7727..244311f7cc9ee 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -178,7 +178,7 @@ void test(queue &q) { } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index f92548d2f7ed8..fcfb87545ff6b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -50,8 +50,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, half, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -71,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 6111f503007f5..a5b1fddcbbbb8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp index b6fe3f0376ffd..e042a8e282b56 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -65,7 +65,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 1c1c4f97819bf..46c49f6576e44 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,7 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index 048aed6341f6c..b44661a1f269f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -78,7 +78,8 @@ void matrix_multiply(big_matrix &C, myparams2::joint_matrix_a sub_a; myparams2::joint_matrix_b< - sub_group, ext::intel::experimental::matrix::layout::packed> + sub_group, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; myparams2::joint_matrix_accumulator sub_c; @@ -99,7 +100,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 5cca6572cef21..5820544722f55 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -51,8 +51,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -68,7 +69,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 397fcc9a5aa97..628735e8523e2 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -51,8 +51,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, uint8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -72,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index 4d4ba0ee951e9..e8a4ece3bb00f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -76,15 +76,13 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_data_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; }); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 24f6cce4cc09d..8634efde40c74 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -42,8 +42,9 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. - joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load(sg, sub_c, @@ -55,7 +56,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 1d82f8833aba6..493ece5bc7d5e 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -53,8 +53,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -75,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index e400b6694e4a9..d668fa604d395 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -51,8 +51,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, uint8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -73,7 +74,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 80b67b14a55ac..4790a1dcb1bf6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 31f77dc55b16f..30704f8869778 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 7c24179022d55..3331c66d302b6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index 0e0b4ce903be2..e4250952fa0c2 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 575039723d56e..50fceeb6c34dc 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index 8ada375fff395..e8b77588237d1 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 69bc136e79776..256bf847645e8 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index e5935b8b3af47..4c3c97c77edd8 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -29,8 +29,9 @@ int main(void) { joint_matrix tA; - joint_matrix + joint_matrix< + sub_group, unsigned short, use::b, 16, 16, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> tB; joint_matrix tC; @@ -49,7 +50,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_mad(sg, tA, tB, tC, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, layout::row_major); diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp index bb18a21bc1002..022ac65612a28 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp @@ -47,7 +47,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 3, 3) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef [[#]], i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32, matrix_layout::packed_b); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_mad(sg, tA, tB, tC, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 0, 3) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, matrix_layout::row_major); diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index 391a9be2197c6..2a4ac877a3864 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -89,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index 6c6bfc1066f01..f58fe277d0073 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index 022e69f9b75a2..3271b27bc466d 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp @@ -87,7 +87,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index feddb05148c4e..5285b8eb9aa2b 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } auto wi_data_c = sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index 335529ad3120a..47af048171265 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index 77c57b4ef711e..614faf8defe5a 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index b141f2176971d..769efbdb0d959 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -154,8 +154,9 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { sub_group sg = spmd_item.get_sub_group(); // TK = 32, TN = 16 - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_load( @@ -166,8 +167,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row - auto wiData = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wiData = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); size_t global_index; // Index into the result array that holds the sums. diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 2e0e309081464..eae42973d45c4 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -68,7 +68,8 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::oneapi::experimental::matrix::layout:: + ext_intel_packed> sub_b; joint_matrix sub_c; @@ -89,7 +90,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 3205e4c346ba6..861727c3fe92d 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -69,7 +69,8 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::oneapi::experimental::matrix::layout:: + ext_intel_packed> sub_b; joint_matrix sub_c; @@ -93,10 +94,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } - auto wi_data_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); for (int i = 0; i < wi_data_c.length(); i++) { wi_data_c[i] *= 2; } diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index f8dcc26ab1b17..959f5b2b30871 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -74,7 +74,8 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::oneapi::experimental::matrix::layout:: + ext_intel_packed> sub_b; joint_matrix sub_c; @@ -94,7 +95,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index d6affb4067003..64852b4a890cc 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -87,12 +87,11 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_data_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index 9afc8e1173043..96d2fd3c0c26a 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -64,7 +64,8 @@ void query_amx() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_group, + sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; @@ -144,7 +145,8 @@ void query_xmx8() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_group, + sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; From 5fbb285ad0dd6727b58d2865632d53cc829db16d Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 12:11:32 +0800 Subject: [PATCH 02/23] clang-format --- .../sycl/ext/oneapi/matrix/matrix-intel.hpp | 8 +- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 6 +- .../Matrix/Legacy/element_wise_ops_impl.hpp | 97 ++++++++++--------- .../Legacy/elemwise_irreg_size_ops_bf16.cpp | 3 +- .../Matrix/Legacy/joint_matrix_bf16_impl.hpp | 3 +- .../joint_matrix_bfloat16_32x64_impl.hpp | 3 +- ...trix_bfloat16_colmajorA_colmajorB_impl.hpp | 3 +- .../Legacy/joint_matrix_bfloat16_impl.hpp | 3 +- ...trix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 3 +- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 85 ++++++++-------- ...t_matrix_int8_colmajorA_colmajorB_impl.hpp | 3 +- .../Legacy/joint_matrix_int8_vnni_impl.hpp | 3 +- .../Legacy/joint_matrix_query_default.cpp | 3 +- .../Legacy/joint_matrix_ss_int8_impl.hpp | 77 +++++++-------- .../Legacy/joint_matrix_su_int8_impl.hpp | 89 ++++++++--------- .../Legacy/joint_matrix_us_int8_impl.hpp | 3 +- .../Legacy/joint_matrix_uu_int8_impl.hpp | 89 ++++++++--------- .../Matrix/element_wise_all_ops_half_impl.hpp | 15 +-- .../Matrix/element_wise_all_ops_impl.hpp | 6 +- .../Matrix/element_wise_all_ops_int8_impl.hpp | 15 +-- .../Matrix/element_wise_all_ops_tf32_impl.hpp | 15 +-- .../Matrix/element_wise_all_sizes_impl.hpp | 47 +++++---- .../Matrix/joint_matrix_down_convert_impl.hpp | 9 +- .../Matrix/joint_matrix_out_bounds_impl.hpp | 3 +- .../matrix/matrix-nvptx-bfloat16-test.cpp | 18 ++-- .../cuda/matrix/matrix-nvptx-double-test.cpp | 6 +- .../matrix/matrix-nvptx-half-float-test.cpp | 18 ++-- .../matrix/matrix-nvptx-half-half-test.cpp | 18 ++-- .../cuda/matrix/matrix-nvptx-int8-test.cpp | 18 ++-- .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 6 +- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 18 ++-- .../matrix/legacy/matrix-bf16-test-SG-16.cpp | 3 +- sycl/test/matrix/legacy/matrix-bf16-test.cpp | 3 +- .../matrix/legacy/matrix-bfloat16-test.cpp | 3 +- .../matrix/legacy/matrix-elemwise-ops.cpp | 3 +- .../matrix/legacy/matrix-int8-test-SG-16.cpp | 3 +- sycl/test/matrix/legacy/matrix-int8-test.cpp | 3 +- 37 files changed, 371 insertions(+), 340 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index eac624a9c3360..c3d4e973ef643 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -187,7 +187,7 @@ class wi_element { #if __SYCL_DEVICE_ONLY__ #define OP(op) \ - template wi_element &operator op##=(const T2 &rhs) { \ + template wi_element &operator op##=(const T2 & rhs) { \ M.spvm = __spirv_VectorInsertDynamic( \ M.spvm, \ static_cast( \ @@ -202,7 +202,7 @@ class wi_element { } #else // __SYCL_DEVICE_ONLY__ #define OP(op) \ - template wi_element &operator op##=(const T2 &rhs) { \ + template wi_element &operator op##=(const T2 & rhs) { \ (void)rhs; \ throw runtime_error("joint matrix is not supported on host device.", \ PI_ERROR_INVALID_DEVICE); \ @@ -306,7 +306,7 @@ class wi_element -void joint_matrix_copy(Group sg, - joint_matrix &src, - joint_matrix &dst) { +void joint_matrix_copy( + Group sg, joint_matrix &src, + joint_matrix &dst) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 7179b82855f50..dcc5764db1fcb 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -38,56 +38,57 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - auto wi_slice_c = sub_c.get_wi_data(); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] *= 2; - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + auto wi_slice_c = sub_c.get_wi_data(); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_c[i] *= 2; + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index d041cd6050d73..d0ee5869b1d54 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -100,7 +100,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 3ba9ae346e070..060830074d1fe 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -76,7 +76,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index 8835bce054171..b8ed30ae495b9 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,7 +68,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 48c4a894ab385..5738b3a109a9c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -62,7 +62,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index db16f05673321..9358ffb4168cd 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -68,7 +68,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 5676cd849e1b5..1a567c9147097 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -62,7 +62,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index 658985c10ab5c..86439c83ff840 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -37,50 +37,51 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 85bffe8957e77..a8c97d97aced3 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -72,7 +72,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index 2cb8196a49306..9a9d895f4771c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -64,7 +64,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index 6b7878b5cce7d..fb3dceee3f361 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -97,7 +97,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index ef594dc6bc3a8..1598f8b78b3a2 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -38,46 +38,47 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - joint_matrix_fill(sg, sub_c, 0); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_fill(sg, sub_c, 0); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 0eb5bc2bc8d58..34320622f006c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -38,52 +38,53 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index 3d0067cbb7b36..bc0ed40202116 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -78,7 +78,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index be4eee8452c3c..164892179f674 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -38,52 +38,53 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index 063acee4c2ffe..df786d73a78ea 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -41,8 +41,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + static_cast(2); } @@ -76,8 +75,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - static_cast(2); } @@ -111,8 +109,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); } @@ -146,8 +143,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); } @@ -181,8 +177,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > static_cast(2.0) || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 25d77269b03a4..447dff879cb01 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -63,8 +63,7 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { layout::row_major> sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); + auto wi_slice = sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } @@ -104,8 +103,7 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { joint_matrix sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); + auto wi_slice = sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index 5d1668b753baa..c18e371711824 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -40,8 +40,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -75,8 +74,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - 2; } @@ -110,8 +108,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * 3; } @@ -145,8 +142,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / 2; } @@ -180,8 +176,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp index 96fcfa975b408..e73056639eb74 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp @@ -41,8 +41,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -77,8 +76,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); } @@ -111,8 +109,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix sub_a; joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); } @@ -146,8 +143,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(4.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); } @@ -180,8 +176,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index b18ca01193974..f1aacbef1230f 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -49,30 +49,29 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { q.submit([&](handler &cgh) { sycl::accessor accA{bufA, cgh, sycl::read_write}; - cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( - SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a; - - joint_matrix_fill(sg, sub_a, val1); - - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + val2; - } - - ext::intel::experimental::matrix::joint_matrix_store( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + sg_starty / SG_SZ * TK, - K); - }); // parallel for + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; + + joint_matrix_fill(sg, sub_a, val1); + + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] + val2; + } + + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + sg_starty / SG_SZ * TK, + K); + }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(), result); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp index 6972e3854c8e8..5dd21cfe5340b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp @@ -49,14 +49,7 @@ void matrix_copy(big_matrix &C, big_matrix &A) { (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); // This will be replaced by joint_matrix_copy API - // joint_matrix_copy(sg, sub_c, sub_ac); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_a[i] = (bfloat16)wi_slice_c[i]; - } + joint_matrix_copy(sg, sub_c, sub_a); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 46c49f6576e44..63eb4af659170 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,7 +58,8 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 4790a1dcb1bf6..336db7c2f00e9 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +89,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +121,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +153,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +185,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +217,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 30704f8869778..c8fcccf2d015e 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +99,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 3331c66d302b6..b1e4015460f3f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +88,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +120,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +152,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +184,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +216,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index e4250952fa0c2..0cbc24560c589 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +88,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +120,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +152,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +184,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +216,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 50fceeb6c34dc..0743dc0a7ffc1 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +88,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +120,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +152,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +184,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +216,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index e8b77588237d1..f6a10c56cd866 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,8 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +138,8 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 256bf847645e8..448dc86f3321f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +88,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +120,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +152,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +184,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +216,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index 2a4ac877a3864..b45e32786acf6 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -89,7 +89,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index f58fe277d0073..f2bc1f3c5618e 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -88,7 +88,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index 3271b27bc466d..d41d9152cd157 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp @@ -87,7 +87,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index 5285b8eb9aa2b..a13b79deea9f2 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -88,7 +88,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } auto wi_data_c = sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index 47af048171265..09aa3349c38d1 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -88,7 +88,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index 614faf8defe5a..457f5b4747d54 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -88,7 +88,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, From bf6cd56fe117de6cb639545870b3fb8d0c8a361f Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 12:25:08 +0800 Subject: [PATCH 03/23] fix typo: dest->dst --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 55a163f4dc38e..90f267bff19c4 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -428,7 +428,7 @@ void joint_matrix_copy( #if defined(__NVPTX__) std::ignore = sg; for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) { - dest.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; + dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; } #else using storage_element_type = From b399041060b065afa32904f1181af9aef570ed57 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 13:53:35 +0800 Subject: [PATCH 04/23] fix testcase --- .../Matrix/Legacy/element_wise_ops_impl.hpp | 3 +-- .../Legacy/elemwise_irreg_size_ops_bf16.cpp | 3 +-- .../Matrix/Legacy/joint_matrix_bf16_impl.hpp | 3 +-- .../joint_matrix_bfloat16_32x64_impl.hpp | 3 +-- ...atrix_bfloat16_colmajorA_colmajorB_impl.hpp | 3 +-- .../Legacy/joint_matrix_bfloat16_impl.hpp | 3 +-- ...atrix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 3 +-- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 3 +-- ...nt_matrix_int8_colmajorA_colmajorB_impl.hpp | 3 +-- .../Legacy/joint_matrix_int8_vnni_impl.hpp | 3 +-- .../Legacy/joint_matrix_query_default.cpp | 3 +-- .../Legacy/joint_matrix_ss_int8_impl.hpp | 3 +-- .../Legacy/joint_matrix_su_int8_impl.hpp | 3 +-- .../Legacy/joint_matrix_us_int8_impl.hpp | 3 +-- .../Legacy/joint_matrix_uu_int8_impl.hpp | 3 +-- .../Matrix/joint_matrix_out_bounds_impl.hpp | 3 +-- .../cuda/matrix/matrix-nvptx-bfloat16-test.cpp | 18 ++++++------------ .../cuda/matrix/matrix-nvptx-double-test.cpp | 6 ++---- .../matrix/matrix-nvptx-half-float-test.cpp | 18 ++++++------------ .../matrix/matrix-nvptx-half-half-test.cpp | 18 ++++++------------ .../cuda/matrix/matrix-nvptx-int8-test.cpp | 18 ++++++------------ .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 6 ++---- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 18 ++++++------------ .../matrix/legacy/matrix-bf16-test-SG-16.cpp | 3 +-- sycl/test/matrix/legacy/matrix-bf16-test.cpp | 3 +-- .../matrix/legacy/matrix-bfloat16-test.cpp | 3 +-- .../test/matrix/legacy/matrix-elemwise-ops.cpp | 3 +-- .../matrix/legacy/matrix-int8-test-SG-16.cpp | 3 +-- sycl/test/matrix/legacy/matrix-int8-test.cpp | 3 +-- 29 files changed, 56 insertions(+), 112 deletions(-) diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index dcc5764db1fcb..51eb10095b4a5 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -76,8 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index d0ee5869b1d54..b6e2e5fbe2315 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -100,8 +100,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 060830074d1fe..847ae955fd41e 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -76,8 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index b8ed30ae495b9..49775fb18d437 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,8 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 5738b3a109a9c..d322210bf7728 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -62,8 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index 9358ffb4168cd..610ac5158794a 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -68,8 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 1a567c9147097..11b74d7270d27 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -62,8 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index 86439c83ff840..c1f780e09144e 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -73,8 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index a8c97d97aced3..6fb5b3981879e 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -72,8 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index 9a9d895f4771c..066934b98221c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -64,8 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index fb3dceee3f361..0dd9cf7e1ec6c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -97,8 +97,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index 1598f8b78b3a2..c1a5d1c762e14 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -70,8 +70,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 34320622f006c..630708e0b54aa 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -76,8 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index bc0ed40202116..8072f813fdb26 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -78,8 +78,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 164892179f674..6ee550537f285 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -76,8 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 63eb4af659170..67f99facdd96f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,8 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 336db7c2f00e9..784c3ad489cb6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,8 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -89,8 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -121,8 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -153,8 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -185,8 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -217,8 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index c8fcccf2d015e..0090805e7a55c 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,8 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -99,8 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index b1e4015460f3f..209341a71e03d 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,8 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,8 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -120,8 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -152,8 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -184,8 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -216,8 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index 0cbc24560c589..e78fbe523dd29 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,8 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,8 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -120,8 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -152,8 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -184,8 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -216,8 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 0743dc0a7ffc1..743f9fd54e12e 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,8 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,8 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -120,8 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -152,8 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -184,8 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -216,8 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index f6a10c56cd866..d3e28e94e5e71 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,8 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -138,8 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 448dc86f3321f..1aa82e27f6c68 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,8 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,8 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -120,8 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -152,8 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -184,8 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -216,8 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index b45e32786acf6..c33848a81a2ed 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -89,8 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index f2bc1f3c5618e..bd989a6e34d0f 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -88,8 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index d41d9152cd157..92715e3b488da 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp @@ -87,8 +87,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index a13b79deea9f2..3cb773a2c2239 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -88,8 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_data_c = sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index 09aa3349c38d1..0bc7b66b2e878 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -88,8 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index 457f5b4747d54..b3fe44fa56250 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -88,8 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, From dae1ec6dabedfb65e92884ecccd162472aeba3cc Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 14:16:27 +0800 Subject: [PATCH 05/23] fix mad bug --- .../test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp | 2 +- .../Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp | 2 +- .../Matrix/Legacy/joint_matrix_bf16_impl.hpp | 2 +- .../Legacy/joint_matrix_bfloat16_32x64_impl.hpp | 2 +- ...oint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_bfloat16_impl.hpp | 2 +- ...oint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 2 +- .../joint_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_query_default.cpp | 2 +- .../Matrix/Legacy/joint_matrix_ss_int8_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_su_int8_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_us_int8_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_uu_int8_impl.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp | 2 +- .../cuda/matrix/matrix-nvptx-bfloat16-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-double-test.cpp | 4 ++-- .../cuda/matrix/matrix-nvptx-half-float-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-half-half-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-int8-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 ++-- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 12 ++++++------ sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp | 2 +- sycl/test/matrix/legacy/matrix-bf16-test.cpp | 2 +- sycl/test/matrix/legacy/matrix-bfloat16-test.cpp | 2 +- sycl/test/matrix/legacy/matrix-elemwise-ops.cpp | 2 +- sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp | 2 +- sycl/test/matrix/legacy/matrix-int8-test.cpp | 2 +- 29 files changed, 56 insertions(+), 56 deletions(-) diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 51eb10095b4a5..2ef278e229ff5 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index b6e2e5fbe2315..0f57377c571ac 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -100,7 +100,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 847ae955fd41e..9868aef0d92e2 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index 49775fb18d437..ac4a0bc405816 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index d322210bf7728..91845ac61a180 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -62,7 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index 610ac5158794a..2598905f9f6fe 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 11b74d7270d27..c293d8ff22944 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -62,7 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index c1f780e09144e..81ca8faa4977d 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -73,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 6fb5b3981879e..21347c80c083b 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index 066934b98221c..1948071dbf405 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index 0dd9cf7e1ec6c..8aaf737a274a8 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -97,7 +97,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index c1a5d1c762e14..20e14381f24bf 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -70,7 +70,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 630708e0b54aa..7c80fff72a55f 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index 8072f813fdb26..68cf40bb481b9 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -78,7 +78,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 6ee550537f285..895b7f0339cfe 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 67f99facdd96f..1c1c4f97819bf 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,7 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 784c3ad489cb6..958aba55c3b46 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 0090805e7a55c..567e3293e1862 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 209341a71e03d..93d431061a3e0 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index e78fbe523dd29..ba4d27a2feb89 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 743f9fd54e12e..2d012581faf8b 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index d3e28e94e5e71..a69246dac7315 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 1aa82e27f6c68..c22a50908d1c7 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index c33848a81a2ed..391a9be2197c6 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -89,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index bd989a6e34d0f..6c6bfc1066f01 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index 92715e3b488da..022e69f9b75a2 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp @@ -87,7 +87,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index 3cb773a2c2239..feddb05148c4e 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_data_c = sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index 0bc7b66b2e878..335529ad3120a 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index b3fe44fa56250..77c57b4ef711e 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, From 4ec8360e83f5389ba68afaf9cbc2ca70959c4137 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 15:35:38 +0800 Subject: [PATCH 06/23] fix cuda const joint_matrix_cuda --- .../sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 10 +++++----- .../matrix/matrix_load_store_as_legacy.cpp | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 94ae318540012..1ab3b56b79ca2 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -482,11 +482,11 @@ void joint_matrix_mad_cuda( joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, - joint_matrix_cuda &A, - joint_matrix_cuda &B, - joint_matrix_cuda< + const joint_matrix_cuda &A, + const joint_matrix_cuda &B, + const joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { if constexpr (M == 16 && N == 16 && K == 16) { diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp index 022ac65612a28..bb18a21bc1002 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp @@ -47,7 +47,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 3, 3) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef [[#]], i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32, matrix_layout::packed_b); - joint_matrix_mad(sg, tA, tB, tC, tC); + tC = joint_matrix_mad(sg, tA, tB, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 0, 3) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, matrix_layout::row_major); From a461cbb88c64518d58a58e04218530512042fe8e Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 16:04:10 +0800 Subject: [PATCH 07/23] fix const issue of jm_store_cuda --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 1ab3b56b79ca2..ac532dbb1886a 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -415,7 +415,8 @@ template &src, + NumCols, const sycl::ext::oneapi::experimental::matrix::layout::dynamic> + &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { switch (Layout) { From 5ff715bef9394e9028f00273252da777a5f7d56f Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 18:07:15 +0800 Subject: [PATCH 08/23] fix const --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index ac532dbb1886a..ca912a04cedc5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -413,9 +413,9 @@ void store_layoutT( template void joint_matrix_store_cuda( - joint_matrix_cuda< + const joint_matrix_cuda< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, - NumCols, const sycl::ext::oneapi::experimental::matrix::layout::dynamic> + NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { From 8ad7da922f4c00763098dd2f160675f161e7b749 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 22:06:03 +0800 Subject: [PATCH 09/23] lint --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index ca912a04cedc5..849cb676c6613 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -415,8 +415,7 @@ template - &src, + NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { switch (Layout) { From 26ea49da18fe71a3e9c94e179b3eff735ea1cc3a Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 14:52:53 +0800 Subject: [PATCH 10/23] address dounia's comments and roll back all the testcase changes --- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 29 +++--- .../Matrix/Legacy/element_wise_ops_impl.hpp | 96 +++++++++---------- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 84 ++++++++-------- .../Legacy/joint_matrix_ss_int8_impl.hpp | 76 +++++++-------- .../Legacy/joint_matrix_su_int8_impl.hpp | 88 ++++++++--------- .../Legacy/joint_matrix_uu_int8_impl.hpp | 88 ++++++++--------- .../test-e2e/Matrix/element_wise_abc_impl.hpp | 14 +-- .../Matrix/element_wise_all_ops_half_impl.hpp | 15 ++- .../Matrix/element_wise_all_ops_impl.hpp | 6 +- .../Matrix/element_wise_all_ops_int8_impl.hpp | 15 ++- .../element_wise_all_ops_int8_packed_impl.hpp | 40 ++++---- .../Matrix/element_wise_all_ops_tf32_impl.hpp | 15 ++- .../Matrix/element_wise_all_sizes_impl.hpp | 47 ++++----- .../element_wise_irreg_sum_rows_impl.hpp | 8 +- .../test-e2e/Matrix/element_wise_ops_impl.hpp | 10 +- .../Matrix/elemwise_irreg_size_ops_bf16.cpp | 10 +- .../Matrix/joint_matrix_all_sizes_impl.hpp | 7 +- .../Matrix/joint_matrix_apply_cuda.hpp | 65 ++++++------- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 53 +++++----- .../joint_matrix_bfloat16_32x64_impl.hpp | 7 +- .../joint_matrix_bfloat16_array_impl.hpp | 7 +- ...trix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_bfloat16_impl.hpp | 7 +- ...trix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../joint_matrix_colA_rowB_colC_impl.hpp | 2 +- .../Matrix/joint_matrix_down_convert_impl.hpp | 9 +- .../Matrix/joint_matrix_gemm_cuda.hpp | 2 +- .../Matrix/joint_matrix_half_impl.hpp | 7 +- ...t_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_int8_vnni_impl.hpp | 2 +- .../Matrix/joint_matrix_query_default.cpp | 5 +- .../Matrix/joint_matrix_ss_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_su_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_tf32_impl.hpp | 8 +- .../Matrix/joint_matrix_transposeC_impl.hpp | 7 +- .../Matrix/joint_matrix_us_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_uu_int8_impl.hpp | 7 +- .../matrix/matrix-nvptx-bfloat16-test.cpp | 12 +-- .../cuda/matrix/matrix-nvptx-double-test.cpp | 4 +- .../matrix/matrix-nvptx-half-float-test.cpp | 12 +-- .../matrix/matrix-nvptx-half-half-test.cpp | 12 +-- .../cuda/matrix/matrix-nvptx-int8-test.cpp | 12 +-- .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 +- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 12 +-- .../matrix/matrix_load_store_as.cpp | 7 +- .../matrix-bfloat16-test-coord-basicB.cpp | 8 +- sycl/test/matrix/matrix-bfloat16-test.cpp | 5 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 8 +- sycl/test/matrix/matrix-int8-test.cpp | 5 +- sycl/test/matrix/matrix-tf32-test.cpp | 5 +- sycl/test/matrix/query-use.cpp | 6 +- 51 files changed, 496 insertions(+), 479 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 90f267bff19c4..bf2441ac17c2e 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -87,9 +87,10 @@ class wi_data { #if defined(__NVPTX__) return jm.cuda_impl.wi_marray.size(); #else - throw runtime_error("get_wi_data is available using: " - "ext::oneapi::detail::get_wi_data.", - PI_ERROR_INVALID_DEVICE); + throw runtime_error( + "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but " + "intel users are expected to use joint_matrix_copy instead.", + PI_ERROR_INVALID_DEVICE); #endif }; @@ -97,9 +98,10 @@ class wi_data { #if defined(__NVPTX__) return (jm.cuda_impl.wi_marray[i]); #else - throw runtime_error("get_wi_data is available using: " - "ext::oneapi::detail::get_wi_data.", - PI_ERROR_INVALID_DEVICE); + throw runtime_error( + "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but " + "intel users are expected to use joint_matrix_copy instead.", + PI_ERROR_INVALID_DEVICE); #endif }; }; @@ -129,7 +131,7 @@ __SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please " #else __attribute__(( unavailable("get_wi_data can't be used on intel device, please use " - "sycl::ext::oneapi::detail::get_wi_data instead!"))) + "joint_matrix_apply instead!"))) #endif #endif inline __SYCL_ALWAYS_INLINE decltype(auto) @@ -251,7 +253,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed: + case layout::ext_intel_packed: res.spvm = __spirv_JointMatrixLoadINTEL< DecorT, S, NumRows, NumCols, spv_matrix_use_traits::value, @@ -351,7 +353,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed: + case layout::ext_intel_packed: __spirv_JointMatrixStoreINTEL< DecorT, T, NumRows, NumCols, spv_matrix_use_traits::value, @@ -376,13 +378,14 @@ template inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( - Group sg, const joint_matrix &A, + Group sg, + joint_matrix &D, + const joint_matrix &A, const joint_matrix &B, const joint_matrix - &C, - joint_matrix &D) { + &C) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 2ef278e229ff5..8d15b78fd3198 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -38,56 +38,56 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - auto wi_slice_c = sub_c.get_wi_data(); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] *= 2; - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + auto wi_slice_c = sub_c.get_wi_data(); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_c[i] *= 2; + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index 81ca8faa4977d..c663cc282c758 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -37,50 +37,50 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index 20e14381f24bf..a2436bc56e792 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -38,46 +38,46 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - joint_matrix_fill(sg, sub_c, 0); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_fill(sg, sub_c, 0); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 7c80fff72a55f..f0a9a7155fb0d 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -38,52 +38,52 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 895b7f0339cfe..14190434dd2b1 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -38,52 +38,52 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 1da92f4fd4dbc..8b7ee3af2b9c5 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -55,9 +55,8 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, T2, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -66,7 +65,8 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM) * K, K); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] += 1; } @@ -76,7 +76,8 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] += 1; } @@ -86,7 +87,8 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); + auto wi_slice_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 1; } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index df786d73a78ea..540d75c245815 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -41,7 +41,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + static_cast(2); } @@ -75,7 +76,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - static_cast(2); } @@ -109,7 +111,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); } @@ -143,7 +146,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); } @@ -177,7 +181,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > static_cast(2.0) || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 447dff879cb01..8e15488e151a0 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -63,7 +63,8 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { layout::row_major> sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); + auto wi_slice = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } @@ -103,7 +104,8 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { joint_matrix sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); + auto wi_slice = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index c18e371711824..803ebe0addb3a 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -40,7 +40,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -74,7 +75,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - 2; } @@ -108,7 +110,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * 3; } @@ -142,7 +145,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / 2; } @@ -176,7 +180,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index a04d464605fef..ce89a04b4168c 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -36,14 +36,14 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] + 2; } @@ -73,14 +73,14 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] - 2; } @@ -110,14 +110,14 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] * 3; } @@ -147,14 +147,14 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 4); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] / 2; } @@ -184,14 +184,14 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { if (wi_slice_b[i]) { if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp index e73056639eb74..27eacf89c748a 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp @@ -41,7 +41,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -76,7 +77,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); } @@ -109,7 +111,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix sub_a; joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); } @@ -143,7 +146,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(4.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); } @@ -176,7 +180,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index f1aacbef1230f..c49f9b57e2f32 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -49,29 +49,30 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { q.submit([&](handler &cgh) { sycl::accessor accA{bufA, cgh, sycl::read_write}; - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a; - - joint_matrix_fill(sg, sub_a, val1); - - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + val2; - } - - ext::intel::experimental::matrix::joint_matrix_store( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + sg_starty / SG_SZ * TK, - K); - }); // parallel for + cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( + SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; + + joint_matrix_fill(sg, sub_a, val1); + + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] + val2; + } + + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + sg_starty / SG_SZ * TK, + K); + }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(), result); } diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index 6e3a86d7cb77b..cfce95cba269f 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -44,9 +44,8 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, T, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_load( @@ -58,7 +57,8 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto data = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index edf9f162d5c3a..1206b556339a9 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -51,9 +51,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -73,9 +72,10 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); + auto wi_slice_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] *= 2; } diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 08724fbb48e9d..8e2865de207b4 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -80,9 +80,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; joint_matrix_load( @@ -102,9 +101,10 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); + auto wi_slice_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 5.0; } diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index f4572990ded76..469837cd26d41 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -56,9 +56,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, T2, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -79,7 +78,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, (k * TK / vnniFactor) * (N * vnniFactor) + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index b58e0a1b7c467..4303239cefe32 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -51,34 +51,34 @@ void matrix_verify_lambda(queue q, q.submit([&](handler &cgh) { accessor accC(bufC, cgh); - cgh.parallel_for>( - r, [accC, lambda]( - nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - auto sg = spmd_item.get_sub_group(); - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - - joint_matrix_fill(sg, sub_a, 3); - joint_matrix_fill(sg, sub_b, 1); - joint_matrix_fill(sg, sub_c, -80); - - joint_matrix_apply(sg, sub_a, lambda); - - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); - - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); - }); // parallel for + cgh.parallel_for>(r, [ + accC, lambda + ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; + + joint_matrix_fill(sg, sub_a, 3); + joint_matrix_fill(sg, sub_b, 1); + joint_matrix_fill(sg, sub_c, -80); + + joint_matrix_apply(sg, sub_a, lambda); + + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); + }); // parallel for }); } assert_ref(C.get_data(), ref); @@ -111,8 +111,8 @@ void matrix_verify_op(queue q, big_matrix &C, cgh); cgh.parallel_for>( - r, [accC, - Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + r, [ accC, + Op ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); @@ -154,7 +154,7 @@ void matrix_verify_op(queue q, big_matrix &C, } } - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, @@ -162,7 +162,8 @@ void matrix_verify_op(queue q, big_matrix &C, (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, (N * nWGperDim), layout::row_major); }); // parallel for - }).wait(); + }) + .wait(); } assert_ops_ref(C.get_data(), ref); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index 74af36f0238ef..f2d359cfe130b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -147,37 +147,36 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { #endif ; - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix tB[NCACHE1 / tN][KCACHE2 / KCACHE1] #ifdef INIT_LIST = { - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), } #endif ; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index c27df0c01caa4..cc0196660744a 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -46,9 +46,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -69,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index 9cdf9b2435f82..d6390d8061dcc 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -59,9 +59,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a[JM_ARRAY_SZ]; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c[JM_ARRAY_SZ]; @@ -82,7 +81,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM * JM_ARRAY_SZ + TM * i) * K + k * TK, K); - joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i], sub_c[i]); + sub_c[i] = joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i]); } } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 2c627d8b88e31..4847e093127a8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index a35dd11ee8cfa..76ac69da27677 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -45,9 +45,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -67,7 +66,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 08341b3835d57..4d61a733e5927 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp index a90a46e258452..f75da1824d94b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp @@ -48,7 +48,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K); joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN, N); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp index 5dd21cfe5340b..6972e3854c8e8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp @@ -49,7 +49,14 @@ void matrix_copy(big_matrix &C, big_matrix &A) { (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); // This will be replaced by joint_matrix_copy API - joint_matrix_copy(sg, sub_c, sub_a); + // joint_matrix_copy(sg, sub_c, sub_ac); + auto wi_slice_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_a[i] = (bfloat16)wi_slice_c[i]; + } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 244311f7cc9ee..5e451d45d7727 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -178,7 +178,7 @@ void test(queue &q) { } } - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index fcfb87545ff6b..f92548d2f7ed8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -50,9 +50,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, half, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -72,7 +71,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index a5b1fddcbbbb8..6111f503007f5 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp index e042a8e282b56..b6fe3f0376ffd 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -65,7 +65,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index b44661a1f269f..048aed6341f6c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -78,8 +78,7 @@ void matrix_multiply(big_matrix &C, myparams2::joint_matrix_a sub_a; myparams2::joint_matrix_b< - sub_group, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + sub_group, ext::intel::experimental::matrix::layout::packed> sub_b; myparams2::joint_matrix_accumulator sub_c; @@ -100,7 +99,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 5820544722f55..5cca6572cef21 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -51,9 +51,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -69,7 +68,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 628735e8523e2..397fcc9a5aa97 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -51,9 +51,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, uint8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -73,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index e8a4ece3bb00f..4d4ba0ee951e9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -76,13 +76,15 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_data_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; }); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 8634efde40c74..24f6cce4cc09d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -42,9 +42,8 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. - joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; joint_matrix_load(sg, sub_c, @@ -56,7 +55,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 493ece5bc7d5e..1d82f8833aba6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -53,9 +53,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -76,7 +75,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index d668fa604d395..e400b6694e4a9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -51,9 +51,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, uint8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -74,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 958aba55c3b46..80b67b14a55ac 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 567e3293e1862..31f77dc55b16f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 93d431061a3e0..7c24179022d55 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index ba4d27a2feb89..0e0b4ce903be2 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 2d012581faf8b..575039723d56e 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index a69246dac7315..8ada375fff395 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index c22a50908d1c7..69bc136e79776 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index 4c3c97c77edd8..e5935b8b3af47 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -29,9 +29,8 @@ int main(void) { joint_matrix tA; - joint_matrix< - sub_group, unsigned short, use::b, 16, 16, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix tB; joint_matrix tC; @@ -50,7 +49,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32); - joint_matrix_mad(sg, tA, tB, tC, tC); + tC = joint_matrix_mad(sg, tA, tB, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, layout::row_major); diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index 769efbdb0d959..b141f2176971d 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -154,9 +154,8 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { sub_group sg = spmd_item.get_sub_group(); // TK = 32, TN = 16 - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_load( @@ -167,7 +166,8 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row - auto wiData = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wiData = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); size_t global_index; // Index into the result array that holds the sums. diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index eae42973d45c4..2e0e309081464 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -68,8 +68,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::intel::experimental::matrix::layout::packed> sub_b; joint_matrix sub_c; @@ -90,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 861727c3fe92d..3205e4c346ba6 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -69,8 +69,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::intel::experimental::matrix::layout::packed> sub_b; joint_matrix sub_c; @@ -94,9 +93,10 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); + auto wi_data_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); for (int i = 0; i < wi_data_c.length(); i++) { wi_data_c[i] *= 2; } diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index 959f5b2b30871..f8dcc26ab1b17 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -74,8 +74,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::intel::experimental::matrix::layout::packed> sub_b; joint_matrix sub_c; @@ -95,7 +94,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index 64852b4a890cc..d6affb4067003 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -87,11 +87,12 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_data_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index 96d2fd3c0c26a..9afc8e1173043 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -64,8 +64,7 @@ void query_amx() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, - sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed> + sub_group, sycl::ext::intel::experimental::matrix::layout::packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; @@ -145,8 +144,7 @@ void query_xmx8() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, - sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed> + sub_group, sycl::ext::intel::experimental::matrix::layout::packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; From a09a778416f23bec6e4f4f4db185b5baed4c8079 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 15:08:16 +0800 Subject: [PATCH 11/23] test changes: mov D in mad --- sycl/test-e2e/Matrix/element_wise_ops_impl.hpp | 2 +- .../test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp | 4 ++-- .../Matrix/joint_matrix_bfloat16_32x64_impl.hpp | 2 +- .../Matrix/joint_matrix_bfloat16_array_impl.hpp | 2 +- ...oint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp | 2 +- ...oint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_colA_rowB_colC_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp | 2 +- .../joint_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_query_default.cpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_transposeC_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp | 2 +- .../cuda/matrix/matrix-nvptx-bfloat16-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-double-test.cpp | 4 ++-- .../cuda/matrix/matrix-nvptx-half-float-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-half-half-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-int8-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 ++-- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 12 ++++++------ .../matrix/matrix_load_store_as.cpp | 2 +- sycl/test/matrix/matrix-bfloat16-test.cpp | 2 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 2 +- sycl/test/matrix/matrix-int8-test.cpp | 2 +- sycl/test/matrix/matrix-tf32-test.cpp | 2 +- 34 files changed, 62 insertions(+), 62 deletions(-) diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 1206b556339a9..66fb8237c2be3 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } auto wi_slice_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 8e2865de207b4..a74288c71ae16 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -101,7 +101,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } auto wi_slice_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index 469837cd26d41..dbd88e607f243 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -78,7 +78,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, (k * TK / vnniFactor) * (N * vnniFactor) + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index 4303239cefe32..754e6429d0d96 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -71,7 +71,7 @@ void matrix_verify_lambda(queue q, joint_matrix_apply(sg, sub_a, lambda); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, @@ -154,7 +154,7 @@ void matrix_verify_op(queue q, big_matrix &C, } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index cc0196660744a..264961a05ad96 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index d6390d8061dcc..d1c9939318551 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -81,7 +81,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM * JM_ARRAY_SZ + TM * i) * K + k * TK, K); - sub_c[i] = joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i]); + joint_matrix_mad(sg, sub_c[i], sub_a[i], sub_b, sub_c[i]); } } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 4847e093127a8..7c07afcb3ecb7 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index 76ac69da27677..4889e77812d72 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -66,7 +66,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 4d61a733e5927..119554c9b23ad 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp index f75da1824d94b..f24f720715788 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp @@ -48,7 +48,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K); joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 5e451d45d7727..219a3976f4c90 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -178,7 +178,7 @@ void test(queue &q) { } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index f92548d2f7ed8..2ac425955a555 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -71,7 +71,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 6111f503007f5..d2081f01ec167 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp index b6fe3f0376ffd..f4f4d682930a4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -65,7 +65,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 1c1c4f97819bf..008b3531a7ec8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,7 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index 048aed6341f6c..ddc0c14e32de8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -99,7 +99,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 5cca6572cef21..21353754b6580 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 397fcc9a5aa97..80dbbd1afbbc5 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index 4d4ba0ee951e9..33ee1d69b4e35 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -81,7 +81,7 @@ void matrix_multiply(big_matrix &C, for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } auto wi_slice_a = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 24f6cce4cc09d..56dbf30bed4d4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -55,7 +55,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 1d82f8833aba6..2bda766fc290a 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -75,7 +75,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index e400b6694e4a9..832b28ecc0562 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -73,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 80b67b14a55ac..309786a38003f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 31f77dc55b16f..16603407d74b1 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 7c24179022d55..47ddc0fb42f48 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index 0e0b4ce903be2..0468f592b6427 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 575039723d56e..858c8625cc6e9 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index 8ada375fff395..f47a701fe7bc6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 69bc136e79776..c6a1bda15cdcb 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index e5935b8b3af47..20689495d6aa8 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -49,7 +49,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_mad(sg, tC, tA, tB, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, layout::row_major); diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 2e0e309081464..83bfb767e7d79 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -89,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 3205e4c346ba6..a7f7b4526dc74 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -93,7 +93,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index f8dcc26ab1b17..dd83f6dd6242f 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -94,7 +94,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index d6affb4067003..fb52c722d3b9f 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -92,7 +92,7 @@ void matrix_multiply(big_matrix &C, for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, From 821fa89fc74493b0b984a472c5e3d4851d26dbf7 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 15:18:41 +0800 Subject: [PATCH 12/23] testcase changes: ext_intel_layout --- sycl/test-e2e/Matrix/element_wise_abc_impl.hpp | 2 +- .../element_wise_all_ops_int8_packed_impl.hpp | 10 +++++----- .../element_wise_irreg_sum_rows_impl.hpp | 2 +- sycl/test-e2e/Matrix/element_wise_ops_impl.hpp | 2 +- .../Matrix/elemwise_irreg_size_ops_bf16.cpp | 2 +- .../Matrix/get_coord_int8_matB_impl.hpp | 2 +- .../Matrix/joint_matrix_all_sizes_impl.hpp | 2 +- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 18 +++++++++--------- .../joint_matrix_bfloat16_32x64_impl.hpp | 2 +- .../joint_matrix_bfloat16_array_impl.hpp | 2 +- .../Matrix/joint_matrix_bfloat16_impl.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_half_impl.hpp | 2 +- .../Matrix/joint_matrix_out_bounds_impl.hpp | 2 +- .../Matrix/joint_matrix_query_default.cpp | 2 +- .../Matrix/joint_matrix_ss_int8_impl.hpp | 2 +- .../Matrix/joint_matrix_su_int8_impl.hpp | 2 +- .../Matrix/joint_matrix_transposeC_impl.hpp | 2 +- .../Matrix/joint_matrix_us_int8_impl.hpp | 2 +- .../Matrix/joint_matrix_uu_int8_impl.hpp | 2 +- .../matrix/matrix_load_store_as.cpp | 2 +- .../matrix-bfloat16-test-coord-basicB.cpp | 2 +- sycl/test/matrix/matrix-bfloat16-test.cpp | 2 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 2 +- sycl/test/matrix/matrix-int8-test.cpp | 2 +- sycl/test/matrix/query-use.cpp | 4 ++-- 25 files changed, 38 insertions(+), 38 deletions(-) diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 8b7ee3af2b9c5..3efc1c547c169 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -56,7 +56,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index ce89a04b4168c..f7f8f305ac19b 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -37,7 +37,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); @@ -74,7 +74,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); @@ -111,7 +111,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); @@ -148,7 +148,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 4); @@ -185,7 +185,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index cfce95cba269f..d412033e684e7 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -45,7 +45,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_load( diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 66fb8237c2be3..1a92343a27558 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index a74288c71ae16..11db7a04e5295 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -81,7 +81,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load( diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index 414254dc669d4..eb1785e3da44d 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -113,7 +113,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_load( diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index dbd88e607f243..b8b660fb33ee2 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -57,7 +57,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index f2d359cfe130b..94ab1d07646e1 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -148,35 +148,35 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { ; joint_matrix + layout::ext_intel_packed> tB[NCACHE1 / tN][KCACHE2 / KCACHE1] #ifdef INIT_LIST = { joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), } #endif ; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index 264961a05ad96..40cc2ad58bdc6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -47,7 +47,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index d1c9939318551..671cf78b660a1 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -60,7 +60,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c[JM_ARRAY_SZ]; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index 4889e77812d72..ddf731fba24a3 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -46,7 +46,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index 2ac425955a555..c7a09229063eb 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -51,7 +51,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 008b3531a7ec8..51ea6745a8174 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -44,7 +44,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; // bounds-checked load where width and height are added diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index ddc0c14e32de8..ef5f702b4356c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -78,7 +78,7 @@ void matrix_multiply(big_matrix &C, myparams2::joint_matrix_a sub_a; myparams2::joint_matrix_b< - sub_group, ext::intel::experimental::matrix::layout::packed> + sub_group, layout::ext_intel_packed> sub_b; myparams2::joint_matrix_accumulator sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 21353754b6580..8135897f893f9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 80dbbd1afbbc5..2730f0f6184de 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 56dbf30bed4d4..02bad19d0d4f4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -43,7 +43,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load(sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 2bda766fc290a..47c9d82e18479 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -54,7 +54,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index 832b28ecc0562..c132aeafef9d2 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index 20689495d6aa8..22c8203444ab4 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -30,7 +30,7 @@ int main(void) { layout::row_major> tA; joint_matrix + layout::ext_intel_packed> tB; joint_matrix tC; diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index b141f2176971d..0823244cd1dc5 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -155,7 +155,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { // TK = 32, TN = 16 joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_load( diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 83bfb767e7d79..37dc5a1607631 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index a7f7b4526dc74..d5ec19ea0a096 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -69,7 +69,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index dd83f6dd6242f..c4ab58c1deaec 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -74,7 +74,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index 9afc8e1173043..05f62c093fb28 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -64,7 +64,7 @@ void query_amx() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_group, layout::ext_intel_packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; @@ -144,7 +144,7 @@ void query_xmx8() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_group, layout::ext_intel_packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; From a3921b52b62f8cf8b8e6da131da7ad8d2264ed88 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 20:03:57 +0800 Subject: [PATCH 13/23] testcase changes: wi_data=>jm_apply --- .../test-e2e/Matrix/element_wise_abc_impl.hpp | 18 ++----- .../Matrix/element_wise_all_ops_half_impl.hpp | 51 +++++++------------ .../Matrix/element_wise_all_ops_impl.hpp | 13 +---- .../Matrix/element_wise_all_ops_int8_impl.hpp | 41 ++++----------- .../element_wise_all_ops_int8_packed_impl.hpp | 41 ++++----------- .../Matrix/element_wise_all_sizes_impl.hpp | 44 ++++++++-------- .../element_wise_irreg_sum_rows_impl.hpp | 2 +- .../test-e2e/Matrix/element_wise_ops_impl.hpp | 6 +-- .../Matrix/elemwise_irreg_size_ops_bf16.cpp | 6 +-- .../Matrix/get_coord_float_matC_impl.hpp | 12 ++--- .../Matrix/get_coord_int8_matA_impl.hpp | 13 ++--- .../Matrix/get_coord_int8_matB_impl.hpp | 2 +- .../Matrix/joint_matrix_down_convert_impl.hpp | 10 +--- .../Matrix/joint_matrix_tf32_impl.hpp | 10 ++-- .../matrix-bfloat16-test-coord-basicB.cpp | 24 ++++----- sycl/test/matrix/matrix-elemwise-ops.cpp | 6 +-- sycl/test/matrix/matrix-tf32-test.cpp | 7 +-- 17 files changed, 92 insertions(+), 214 deletions(-) diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 3efc1c547c169..37c2a93554eec 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -65,33 +65,21 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM) * K, K); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] += 1; - } + joint_matrix_apply(sg, sub_a, [](T2 &x) { x += 1; }); joint_matrix_load( sg, sub_b, accB.template get_multi_ptr() + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] += 1; - } + joint_matrix_apply(sg, sub_b, [](T2 &x) { x += 1; }); joint_matrix_load( sg, sub_c, accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] += 1; - } + joint_matrix_apply(sg, sub_c, [](T1 &x) { x += 1; }); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index 540d75c245815..42e1afb4d69f1 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -41,11 +41,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + static_cast(2); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x + static_cast(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -76,11 +73,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - static_cast(2); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x - static_cast(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -111,11 +105,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x * static_cast(3.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -146,11 +137,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x / static_cast(2.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -181,30 +169,25 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > static_cast(2.0) || - wi_slice_a[i] >= static_cast(2.0) || - wi_slice_a[i] < static_cast(2.0) || - wi_slice_a[i] <= static_cast(2.0)) { - T val = (wi_slice_a[i] != static_cast(2.0)) - ? wi_slice_a[i] - : static_cast(2.0); + joint_matrix_apply(sg, sub_a, [](T &x) { + if (x) { + if (x > static_cast(2.0) || x >= static_cast(2.0) || + x < static_cast(2.0) || x <= static_cast(2.0)) { + T val = + (x != static_cast(2.0)) ? x : static_cast(2.0); val--; val++; - if (wi_slice_a[i] == static_cast(2.0)) { + if (x == static_cast(2.0)) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 8e15488e151a0..b11d3093bf08d 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -63,12 +63,7 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { layout::row_major> sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); - for (int i = 0; i < wi_slice.length(); i++) { - wi_slice[i] = op(wi_slice[i], r); - } - + joint_matrix_apply(sg, sub_mat, [=](T &x) { x = op(x, r); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_mat, accessMat.template get_multi_ptr() + @@ -104,11 +99,7 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { joint_matrix sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); - for (int i = 0; i < wi_slice.length(); i++) { - wi_slice[i] = op(wi_slice[i], r); - } + joint_matrix_apply(sg, sub_mat, [=](T &x) { x = op(x, r); }); joint_matrix_store( sg, sub_mat, diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index 803ebe0addb3a..4a43d39738657 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -40,11 +40,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x + 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -75,11 +71,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x - 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -110,11 +102,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * 3; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x * 3; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -145,11 +133,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x / 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -180,26 +164,23 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || - wi_slice_a[i] < 2 || wi_slice_a[i] <= 2) { - T val = (wi_slice_a[i] != 2) ? wi_slice_a[i] : 2; + joint_matrix_apply(sg, sub_a, [](T &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + T val = (x != 2) ? x : 2; val--; val++; - if (wi_slice_a[i] == 2) { + if (x == 2) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index f7f8f305ac19b..e3d21a36bd6e1 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -42,11 +42,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] + 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x + 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -79,11 +75,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] - 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x - 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -116,11 +108,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] * 3; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x * 3; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -153,11 +141,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 4); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] / 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x / 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -190,26 +174,23 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - if (wi_slice_b[i]) { - if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || - wi_slice_b[i] < 2 || wi_slice_b[i] <= 2) { - T val = (wi_slice_b[i] != 2) ? wi_slice_b[i] : 2; + joint_matrix_apply(sg, sub_b, [](T &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + T val = (x != 2) ? x : 2; val--; val++; - if (wi_slice_b[i] == 2) { + if (x == 2) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_b[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index c49f9b57e2f32..6e1b6410547ad 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -49,30 +49,26 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { q.submit([&](handler &cgh) { sycl::accessor accA{bufA, cgh, sycl::read_write}; - cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( - SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a; - - joint_matrix_fill(sg, sub_a, val1); - - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + val2; - } - - ext::intel::experimental::matrix::joint_matrix_store( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + sg_starty / SG_SZ * TK, - K); - }); // parallel for + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; + + joint_matrix_fill(sg, sub_a, val1); + + joint_matrix_apply(sg, sub_a, [=](T &x) { x += val2; }); + + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + sg_starty / SG_SZ * TK, + K); + }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(), result); } diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index d412033e684e7..18761986561ac 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -58,7 +58,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 1a92343a27558..1dd9779aa0b56 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -74,11 +74,7 @@ void matrix_multiply(big_matrix &C, N * 4); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] *= 2; - } + joint_matrix_apply(sg, sub_c, [](int32_t &x) { x = x * 2; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 11db7a04e5295..cc8722467d262 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -103,11 +103,7 @@ void matrix_multiply(big_matrix &C, N * 2); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] += 5.0; - } + joint_matrix_apply(sg, sub_c, [](float &x) { x += 5.0; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp b/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp index dea7601437742..f9d19e914e639 100644 --- a/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp @@ -50,15 +50,11 @@ void matrix_sum_rows(big_matrix &C, float *sum_rows) { N, layout::row_major); float sum_local_rows[M] = {0}; - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - - for (int i = 0; i < data.length(); ++i) { - auto dataItem = data[i]; - auto [row, col] = dataItem.get_coord(); - sum_local_rows[row + global_idx * TM] += dataItem; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_c, [&](float &x, size_t row, size_t col) { + sum_local_rows[row + global_idx * TM] += x; + }); for (int i = 0; i < M; i++) { sum_local_rows[i] = reduce_over_group(sg, sum_local_rows[i], sycl::plus<>()); diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp index ec21cfa036807..619f97969b29c 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp @@ -96,16 +96,11 @@ void matrix_sum_rows(queue q, big_matrix &A, nd_range<2> &r) { K); int32_t sum_local_rows[M] = {0}; - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - - // each WI calculates local sum of rows - for (int i = 0; i < data.length(); ++i) { - auto data_item = data[i]; - auto [row, col] = data_item.get_coord(); - sum_local_rows[row + global_idx * TM] += data_item; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_a, [&](int8_t &x, size_t row, size_t col) { + sum_local_rows[row + global_idx * TM] += x; + }); for (int i = 0; i < M; ++i) { sum_local_rows[i] = reduce_over_group(sg, sum_local_rows[i], sycl::plus<>()); diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index eb1785e3da44d..26295e8a5050e 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -124,7 +124,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; auto wiData = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); // each WI calculates local sum of cols for (int i = 0; i < wiData.length(); ++i) { diff --git a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp index 6972e3854c8e8..68e9d3c145675 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp @@ -48,15 +48,7 @@ void matrix_copy(big_matrix &C, big_matrix &A) { accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - // This will be replaced by joint_matrix_copy API - // joint_matrix_copy(sg, sub_c, sub_ac); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_a[i] = (bfloat16)wi_slice_c[i]; - } + joint_matrix_copy(sg, sub_c, sub_a); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index 33ee1d69b4e35..607aba535c74f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -76,15 +76,11 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_data_b.length(); i++) { - wi_data_b[i] = round_to_tf32(wi_data_b[i]); - } + joint_matrix_apply(sg, sub_b, + [=](float x) { x = round_to_tf32(x); }); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; }); joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index 0823244cd1dc5..dd715887e728b 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -166,8 +166,6 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row - auto wiData = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); size_t global_index; // Index into the result array that holds the sums. @@ -175,19 +173,15 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { // Keep track of cols handled in this WI int32_t handled_cols[N] = {-1}; - // each WI calculates local sum of cols - for (int i = 0; i < wiData.length(); ++i) { - // get the index of the element in the submatrix - auto dataItem = wiData[i]; - auto [row, col] = dataItem.get_coord(); - - // Calculation of global index - int sg_idx = (int)global_idy / SG_SZ; - global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; - sum_local_cols[global_index] += wiData[i]; - handled_cols[global_index] = 1; - } - + sycl::ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_b, + [&](int8_t &x, size_t row, + size_t col) { // Calculation of global index + int sg_idx = (int)global_idy / SG_SZ; + global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; + sum_local_cols[global_index] += x; + handled_cols[global_index] = 1; + }); for (int j = 0; j < N; j++) { if (handled_cols[j] == 1) { global_index = j; diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index d5ec19ea0a096..9621f570cf461 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -95,11 +95,7 @@ void matrix_multiply(big_matrix &C, N * 4); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_data_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_data_c.length(); i++) { - wi_data_c[i] *= 2; - } + joint_matrix_apply(sg, sub_c, [](int32_t &x) { x *= 2; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index fb52c722d3b9f..496af7dabd335 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -87,11 +87,8 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_data_b.length(); i++) { - wi_data_b[i] = round_to_tf32(wi_data_b[i]); - } + joint_matrix_apply(sg, sub_b, + [=](float &x) { x = round_to_tf32(x); }); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( From ef1bc6764e2783e5d0bf06b3a09684310352861f Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 20:06:10 +0800 Subject: [PATCH 14/23] lint --- .../test-e2e/Matrix/element_wise_abc_impl.hpp | 3 +- .../element_wise_irreg_sum_rows_impl.hpp | 6 +- .../Matrix/get_coord_int8_matB_impl.hpp | 76 +++++++++---------- .../Matrix/joint_matrix_all_sizes_impl.hpp | 3 +- .../Matrix/joint_matrix_apply_cuda.hpp | 63 ++++++++------- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 40 ++++------ .../Matrix/joint_matrix_query_default.cpp | 4 +- sycl/test/matrix/query-use.cpp | 8 +- 8 files changed, 91 insertions(+), 112 deletions(-) diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 37c2a93554eec..378c46c4b84d5 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -55,8 +55,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index 18761986561ac..683ad694fe26a 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -44,8 +44,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix sub_b; joint_matrix_load( @@ -57,8 +56,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto data = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index 26295e8a5050e..0df698e3ace48 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -103,45 +103,43 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { auto accB = bufB.get_access(cgh); auto v = sum_cols_v.get_access(cgh); - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sycl::sub_group sg = spmd_item.get_sub_group(); - - joint_matrix - sub_b; - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (sg_startx * (TK / VF) * N) + sg_starty / SG_SZ * TN * VF, - N); - - int32_t sum_local_cols[N] = {0}; - auto wiData = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); - - // each WI calculates local sum of cols - for (int i = 0; i < wiData.length(); ++i) { - // get the index of the element in the submatrix - auto dataItem = wiData[i]; - auto [row, col] = dataItem.get_coord(); - size_t global_index = col + global_idy / SG_SZ * TN * VF; - sum_local_cols[global_index] += dataItem; - } - - for (int i = 0; i < N; i++) { - sum_local_cols[i] = - reduce_over_group(sg, sum_local_cols[i], sycl::plus<>()); - if (global_idy % SG_SZ == 0) - atomic_fetch_add(v[i], sum_local_cols[i]); - } - }); // parallel for + cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( + SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sycl::sub_group sg = spmd_item.get_sub_group(); + + joint_matrix + sub_b; + + joint_matrix_load(sg, sub_b, + accB.template get_multi_ptr() + + (sg_startx * (TK / VF) * N) + + sg_starty / SG_SZ * TN * VF, + N); + + int32_t sum_local_cols[N] = {0}; + auto wiData = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + + // each WI calculates local sum of cols + for (int i = 0; i < wiData.length(); ++i) { + // get the index of the element in the submatrix + auto dataItem = wiData[i]; + auto [row, col] = dataItem.get_coord(); + size_t global_index = col + global_idy / SG_SZ * TN * VF; + sum_local_cols[global_index] += dataItem; + } + + for (int i = 0; i < N; i++) { + sum_local_cols[i] = + reduce_over_group(sg, sum_local_cols[i], sycl::plus<>()); + if (global_idy % SG_SZ == 0) + atomic_fetch_add(v[i], sum_local_cols[i]); + } + }); // parallel for }).wait(); sum_cols_ref(bufB.get_host_access(), sum_cols_v.get_host_access()); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index b8b660fb33ee2..00149e0b55ce4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -56,8 +56,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index 754e6429d0d96..e091442b84cb7 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -51,34 +51,34 @@ void matrix_verify_lambda(queue q, q.submit([&](handler &cgh) { accessor accC(bufC, cgh); - cgh.parallel_for>(r, [ - accC, lambda - ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - auto sg = spmd_item.get_sub_group(); - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - - joint_matrix_fill(sg, sub_a, 3); - joint_matrix_fill(sg, sub_b, 1); - joint_matrix_fill(sg, sub_c, -80); - - joint_matrix_apply(sg, sub_a, lambda); - - joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); - - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); - }); // parallel for + cgh.parallel_for>( + r, [accC, lambda]( + nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; + + joint_matrix_fill(sg, sub_a, 3); + joint_matrix_fill(sg, sub_b, 1); + joint_matrix_fill(sg, sub_c, -80); + + joint_matrix_apply(sg, sub_a, lambda); + + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); + + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); + }); // parallel for }); } assert_ref(C.get_data(), ref); @@ -111,8 +111,8 @@ void matrix_verify_op(queue q, big_matrix &C, cgh); cgh.parallel_for>( - r, [ accC, - Op ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { + r, [accC, + Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); @@ -162,8 +162,7 @@ void matrix_verify_op(queue q, big_matrix &C, (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, (N * nWGperDim), layout::row_major); }); // parallel for - }) - .wait(); + }).wait(); } assert_ops_ref(C.get_data(), ref); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index 94ab1d07646e1..b9edadad461a6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -153,30 +153,22 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { #ifdef INIT_LIST = { - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), } #endif ; diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index ef5f702b4356c..7fc3547b56e31 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -77,9 +77,7 @@ void matrix_multiply(big_matrix &C, sycl::sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a; - myparams2::joint_matrix_b< - sub_group, layout::ext_intel_packed> - sub_b; + myparams2::joint_matrix_b sub_b; myparams2::joint_matrix_accumulator sub_c; joint_matrix_load( diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index 05f62c093fb28..fce49a83df30b 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -63,9 +63,7 @@ void query_amx() { [msize, ksize, nsize](nd_item<2> spmd_item) { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b< - sub_group, layout::ext_intel_packed> - sub_b1; + myparams2::joint_matrix_b sub_b1; myparams2::joint_matrix_accumulator sub_c1; joint_matrix sub_a; @@ -143,9 +141,7 @@ void query_xmx8() { [msize, ksize, nsize](nd_item<2> spmd_item) { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b< - sub_group, layout::ext_intel_packed> - sub_b1; + myparams2::joint_matrix_b sub_b1; myparams2::joint_matrix_accumulator sub_c1; joint_matrix sub_a; From 8f2f1971b2af532635546bf10faec7d229ef1cc6 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 22 Sep 2023 16:25:31 +0800 Subject: [PATCH 15/23] handle cuda testcase compfail --- .../sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 12 ++++++------ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 849cb676c6613..94ae318540012 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -413,7 +413,7 @@ void store_layoutT( template void joint_matrix_store_cuda( - const joint_matrix_cuda< + joint_matrix_cuda< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, @@ -482,11 +482,11 @@ void joint_matrix_mad_cuda( joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, - const joint_matrix_cuda &A, - const joint_matrix_cuda &B, - const joint_matrix_cuda< + joint_matrix_cuda &A, + joint_matrix_cuda &B, + joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { if constexpr (M == 16 && N == 16 && K == 16) { diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index bf2441ac17c2e..71a50663e9017 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -40,7 +40,7 @@ struct joint_matrix { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_cuda + mutable sycl::ext::oneapi::detail::joint_matrix_cuda cuda_impl; #elif defined(__SPIR__) __spv::__spirv_JointMatrixINTEL< From 1411376b42069f36341dcd4930bade17dad45f4e Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 22 Sep 2023 16:36:55 +0800 Subject: [PATCH 16/23] address dounia's comments --- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 6 ++--- .../Matrix/element_wise_irreg_sum_rows.cpp | 26 ------------------- .../Matrix/get_coord_int8_matB_impl.hpp | 15 ++++------- sycl/test/matrix/query-use.cpp | 8 ++++-- 4 files changed, 13 insertions(+), 42 deletions(-) delete mode 100644 sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 71a50663e9017..026729b32e788 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -88,8 +88,7 @@ class wi_data { return jm.cuda_impl.wi_marray.size(); #else throw runtime_error( - "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but " - "intel users are expected to use joint_matrix_copy instead.", + "get_wi_data is unavailable, use joint_matrix_copy instead.", PI_ERROR_INVALID_DEVICE); #endif }; @@ -99,8 +98,7 @@ class wi_data { return (jm.cuda_impl.wi_marray[i]); #else throw runtime_error( - "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but " - "intel users are expected to use joint_matrix_copy instead.", + "get_wi_data is unavailable, use joint_matrix_copy instead.", PI_ERROR_INVALID_DEVICE); #endif }; diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp deleted file mode 100644 index 1cb48f1bc4f72..0000000000000 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//==-------- element_wise_irreg_sum_rows.cpp - DPC++ joint_matrix----- ----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// REQUIRES: matrix - -// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -// RUN: %{run} %t.out - -// This code calculates the sum of rows into a global array of number of rows -// elements. First, partial reduction is computed inside each SG, then atomic -// add is used to reduce between SG leaders - -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -#define SG_SZ 16 -constexpr size_t TN = 16; - -#include "element_wise_irreg_sum_rows_impl.hpp" diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index 0df698e3ace48..22259eb0072b0 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -122,16 +122,11 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { N); int32_t sum_local_cols[N] = {0}; - auto wiData = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); - - // each WI calculates local sum of cols - for (int i = 0; i < wiData.length(); ++i) { - // get the index of the element in the submatrix - auto dataItem = wiData[i]; - auto [row, col] = dataItem.get_coord(); - size_t global_index = col + global_idy / SG_SZ * TN * VF; - sum_local_cols[global_index] += dataItem; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_b, [&](int8_t &x, size_t row, size_t col) { + size_t global_index = col + global_idy / SG_SZ * TN * VF; + sum_local_cols[global_index] += x; + }); for (int i = 0; i < N; i++) { sum_local_cols[i] = diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index fce49a83df30b..9afc8e1173043 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -63,7 +63,9 @@ void query_amx() { [msize, ksize, nsize](nd_item<2> spmd_item) { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b sub_b1; + myparams2::joint_matrix_b< + sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_b1; myparams2::joint_matrix_accumulator sub_c1; joint_matrix sub_a; @@ -141,7 +143,9 @@ void query_xmx8() { [msize, ksize, nsize](nd_item<2> spmd_item) { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b sub_b1; + myparams2::joint_matrix_b< + sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_b1; myparams2::joint_matrix_accumulator sub_c1; joint_matrix sub_a; From 95df3b18379e8993958bc8b6b800e345ae9099d3 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 22 Sep 2023 16:40:38 +0800 Subject: [PATCH 17/23] lint --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 026729b32e788..1d7ef2dae065c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -40,7 +40,8 @@ struct joint_matrix { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - mutable sycl::ext::oneapi::detail::joint_matrix_cuda + mutable sycl::ext::oneapi::detail::joint_matrix_cuda cuda_impl; #elif defined(__SPIR__) __spv::__spirv_JointMatrixINTEL< From fb1afdcd5f0b26f073a2bd1a09ea3ab3eb64d75d Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 22 Sep 2023 17:01:59 +0800 Subject: [PATCH 18/23] rm sycl/test/matrix/query-use.cpp --- sycl/test/matrix/query-use.cpp | 162 --------------------------------- 1 file changed, 162 deletions(-) delete mode 100644 sycl/test/matrix/query-use.cpp diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp deleted file mode 100644 index 9afc8e1173043..0000000000000 --- a/sycl/test/matrix/query-use.cpp +++ /dev/null @@ -1,162 +0,0 @@ -// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -fsycl -o query-use %s -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -void query_amx() { - - // generates combination assert - // using myparams = tpu_params; - - // generates types assert - // using myparams2 = tpu_params; - - // tells whether a combination is valid or not, if valid, those will be set as - // default - using myparams = tpu_params; - - size_t dmsize = myparams::M; - size_t dnsize = myparams::N; - size_t dksize = myparams::K; - std::cout << "sizes of AMX tpu_params chosen by the user are: M " << dmsize - << " N " << dnsize << " K " << dksize << std::endl; - - // Sizes-only query: types are given, generate default sizes - using myparams2 = tpu_params; - myparams2 p; - dmsize = myparams2::M; - dnsize = myparams2::N; - dksize = myparams2::K; - std::cout << "default AMX sizes tpu_params are: M " << dmsize << " N " - << dnsize << " K " << dksize << "\n AMX int8 num combinations is " - << p.num_combinations << std::endl; - - // general query: types are not given - tpu_params myparams3; - - if (myparams3.num_scopes > 0) - if (myparams3.scopes[0] == scope_t::sub_group) - std::cout << "There are " << myparams3.num_scopes - << " Scopes that are supported by AMX implementation and " - "subgroup is one of them " - << std::endl; - - std::cout << "AMX query num combinations: " << myparams3.num_combinations - << std::endl; - - if (myparams3.combinations[0].msize != 0) // this is a max params hardware - return; - constexpr int msize = myparams3.combinations[0].max_msize; - constexpr int nsize = myparams3.combinations[0].max_nsize; - constexpr int ksize = myparams3.combinations[0].max_ksize; - std::cout << "AMX query sizes are: M " << msize << " N " << nsize << " K " - << ksize << std::endl; - - size_t NDRangeM = 1024 / msize; - size_t NDRangeN = 1024 / nsize; - queue q; - q.submit([&](handler &cgh) { - cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN}, {1, 1}), - [msize, ksize, nsize](nd_item<2> spmd_item) { - sub_group sg = spmd_item.get_sub_group(); - myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> - sub_b1; - myparams2::joint_matrix_accumulator sub_c1; - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - }); - }); -} - -void query_xmx8() { - - // generates combination assert - // using myparams = tpu_params; - - // generate combination of type assert - // using myparams = tpu_params; - - // tells whether a combination is valid or not, if valid, those will be set as - // default - using myparams = tpu_params; - - size_t dmsize = myparams::M; - size_t dnsize = myparams::N; - size_t dksize = myparams::K; - std::cout << "sizes of XMX8 tpu_params chosen by the user are: M " << dmsize - << " N " << dnsize << " K " << dksize << std::endl; - - // sizes-only query: types are given, generate default sizes - using myparams2 = tpu_params; - myparams2 p; - dmsize = myparams2::M; - dnsize = myparams2::N; - dksize = myparams2::K; - std::cout << "Default XMX8 sizes are: M " << dmsize << " N " << dnsize - << " K " << dksize << "\n XMX8 int8 num combinations is " - << p.num_combinations << std::endl; - - dmsize = myparams2::combinations[0].msize; - dnsize = myparams2::combinations[0].nsize; - dksize = myparams2::combinations[0].ksize; - std::cout << "one of XMX8 combination sizes is: M " << dmsize << " N " - << dnsize << " K " << dksize << std::endl; - - // general query: types are not given - tpu_params myparams3; - - if (myparams3.num_scopes > 0) - if (myparams3.scopes[0] == scope_t::sub_group) - std::cout << "There are " << myparams3.num_scopes - << " Scopes that are supported by XMX8 implementation and " - "subgroup is one of them " - << std::endl; - - std::cout << "XMX8 query num combinations: " << myparams3.num_combinations - << std::endl; - - if (myparams3.combinations[0].msize == 0) // this is not a max params hardware - return; - constexpr int msize = myparams3.combinations[0].msize; - constexpr int nsize = myparams3.combinations[0].nsize; - constexpr int ksize = myparams3.combinations[0].ksize; - std::cout << "XMX8 query sizes are: M " << msize << " N " << nsize << " K " - << ksize << std::endl; - std::cout << "XMX8 query max sizes are: M " - << myparams3.combinations[0].max_msize << " N " - << myparams3.combinations[0].max_nsize << " K " - << myparams3.combinations[0].max_ksize << std::endl; - - size_t NDRangeM = 1024 / msize; - size_t NDRangeN = 1024 / nsize; - queue q; - q.submit([&](handler &cgh) { - cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN}, {1, 1}), - [msize, ksize, nsize](nd_item<2> spmd_item) { - sub_group sg = spmd_item.get_sub_group(); - myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> - sub_b1; - myparams2::joint_matrix_accumulator sub_c1; - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - }); - }); -} - -int main() { - query_amx(); - query_xmx8(); - return 0; -} From 11df5313a65eda4d46cba2c4fff052dcdda7d4be Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Mon, 25 Sep 2023 16:22:14 +0800 Subject: [PATCH 19/23] fix x jm_mad in joint_matrix_bf16_fill_k_cache_impl.hpp --- sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index b9edadad461a6..7efdb03b25a8c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -220,8 +220,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { for (unsigned int n = 0; n < NCACHE1 / tN; n++) { #endif - tC[m][n] = - joint_matrix_mad(sg, tA[m][k1], tB[n][k1], tC[m][n]); + joint_matrix_mad(sg, tC[m][n], tA[m][k1], tB[n][k1], + tC[m][n]); #ifdef MANUAL_UNROLL }); // n }); // m From a82110767505718b630b3de1c33171bd4ef581ec Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 12 Oct 2023 01:31:20 +0800 Subject: [PATCH 20/23] address comments --- .../sycl/ext/oneapi/matrix/matrix-intel.hpp | 10 ++-- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 14 +---- .../Matrix/element_wise_all_ops_tf32_impl.hpp | 59 +++++++------------ 3 files changed, 29 insertions(+), 54 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index 98836da6cc7c3..b852e3f1ff3f5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -522,12 +522,13 @@ joint_matrix_store(Group, } template + sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows, + size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout, + typename F> inline __SYCL_ALWAYS_INLINE void joint_matrix_apply( Group sg, - sycl::ext::oneapi::experimental::matrix::joint_matrix &jm, + sycl::ext::oneapi::experimental::matrix::joint_matrix &jm, F &&lambda) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) @@ -554,7 +555,6 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_apply( throw runtime_error("joint matrix is not supported on host device.", PI_ERROR_INVALID_DEVICE); #endif - return; } } // namespace intel::experimental::matrix diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 431ddc4dd4d82..1389076c52e97 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -87,21 +87,12 @@ class wi_data { size_t length() { #if defined(__NVPTX__) return jm.cuda_impl.wi_marray.size(); -#else - throw runtime_error( - "get_wi_data is unavailable, use joint_matrix_copy instead.", - PI_ERROR_INVALID_DEVICE); #endif }; decltype(auto) operator[](size_t i) { #if defined(__NVPTX__) return (jm.cuda_impl.wi_marray[i]); -#else - std::ignore = i; - throw runtime_error( - "get_wi_data is unavailable, use joint_matrix_copy instead.", - PI_ERROR_INVALID_DEVICE); #endif }; }; @@ -129,9 +120,8 @@ template &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + 2; - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x + round_to_tf32(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, @@ -77,11 +74,9 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x - round_to_tf32(2); }); + ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -111,11 +106,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix sub_a; joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x * round_to_tf32(3.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -146,11 +138,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(4.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x / round_to_tf32(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -180,27 +169,23 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || - wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) { - Ts val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i] : 2.0; - val = val - static_cast(1); - val = val + static_cast(1); - if (wi_slice_a[i] == 2.0) { - val = val - static_cast(2); - val = val * static_cast(3); - val = val / static_cast(2); - + joint_matrix_apply(sg, sub_a, [&](float &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + float val = (x != 2) ? x : 2; + val--; + val++; + if (x == 2) { + val -= 2; + val *= 3; + val /= 2; } else { - val = val + static_cast(2); + val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + From 1d091de81a2cbacce18b51edbac3b656e2fe8798 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 12 Oct 2023 02:51:16 +0800 Subject: [PATCH 21/23] rm element_wise_irreg_sum_rows_impl.hpp --- .../element_wise_irreg_sum_rows_impl.hpp | 105 ------------------ .../element_wise_irreg_sum_rows_impl.hpp | 105 ------------------ 2 files changed, 210 deletions(-) delete mode 100644 sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp delete mode 100644 sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp deleted file mode 100644 index 6a18fe3650f2c..0000000000000 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp +++ /dev/null @@ -1,105 +0,0 @@ -#define TN SG_SZ -#define TK 32 - -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void sum_rows_ref(host_accessor B, - host_accessor sum_rows) { - int sum_rows_ref[M] = {0}; - for (size_t i = 0; i < M; i++) { - for (size_t j = 0; j < N; j++) { - sum_rows_ref[i] += B[i][j]; - } - auto diff = sum_rows[i] - sum_rows_ref[i]; - assert(std::fabs(static_cast(diff)) <= - std::numeric_limits::epsilon()); - } -} - -template -void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { - buffer bufB(B.get_data(), range<2>(M, N)); - // size of vector is known because SG size of set by the user in this case - int sum_rows[M] = {0}; - buffer sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows - q.submit([&](handler &cgh) { - auto accB = bufB.get_access(cgh); - - auto v = sum_rows_v.get_access(cgh); - - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sycl::sub_group sg = spmd_item.get_sub_group(); - - joint_matrix sub_b(sg); - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, - N, matrix_layout::packed_b); - // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b - // (tK/4) - int32_t sum_local_rows[M] = {0}; // 8 local rows, M total - // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = sub_b.get_wi_data(); - - // each WI calculates local sum of rows - for (int row = 0; row < TK / 4; row++) { // there are 8 rows - for (int i = 0; i < data.length() / (TK / 4); i++) { // 4 per row - // i*SG_SIZE index is found based on the round robin - // distribution we are using in the implementation - sum_local_rows[row + global_idx * (TK / 4)] += data[i + row * 4]; - } - sum_local_rows[row + global_idx * (TK / 4)] = reduce_over_group( - sg, sum_local_rows[row + global_idx * (TK / 4)], - sycl::plus<>()); - - // only Groups leader perform the global reduction - if (global_idy % SG_SZ == 0) { - atomic_fetch_add(v[row + global_idx * (TK / 4)], - sum_local_rows[row + global_idx * (TK / 4)]); - } - } - }); // parallel for - }).wait(); - sum_rows_ref(bufB.get_host_access(read_only), - sum_rows_v.get_host_access(read_only)); -} - -static constexpr size_t MATRIX_K = TK / 4 * 2; -static constexpr size_t MATRIX_N = TN * 4 * 2; -int8_t B[MATRIX_K][MATRIX_N]; - -int main() { - big_matrix MB((int8_t *)&B); - - size_t NDRangeK = MATRIX_K / (TK / 4); - size_t NDRangeN = (MATRIX_N / 4) / TN; - queue q; - nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); - - for (int i = 0; i < MATRIX_K; i++) { - for (int j = 0; j < MATRIX_N; j++) { - B[i][j] = i; - } - } - - matrix_sum_rows(q, MB, r); - - return 0; -} diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp deleted file mode 100644 index 683ad694fe26a..0000000000000 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ /dev/null @@ -1,105 +0,0 @@ -#define TK 32 - -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void sum_rows_ref(host_accessor B, - host_accessor sum_rows) { - int sum_rows_ref[M] = {0}; - for (size_t i = 0; i < M; i++) { - for (size_t j = 0; j < N; j++) { - sum_rows_ref[i] += B[i][j]; - } - auto diff = sum_rows[i] - sum_rows_ref[i]; - assert(std::fabs(static_cast(diff)) <= - std::numeric_limits::epsilon()); - } -} - -template -void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { - buffer bufB(B.get_data(), range<2>(M, N)); - // size of vector is known because SG size of set by the user in this case - int sum_rows[M] = {0}; - buffer sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows - q.submit([&](handler &cgh) { - auto accB = bufB.get_access(cgh); - - auto v = sum_rows_v.get_access(cgh); - - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sycl::sub_group sg = spmd_item.get_sub_group(); - - joint_matrix - sub_b; - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, - N); - // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b - // (tK/4) - int32_t sum_local_rows[M] = {0}; // 8 local rows, M total - // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); - - // each WI calculates local sum of rows - for (int row = 0; row < TK / 4; row++) { // there are 8 rows - for (int i = 0; i < data.length() / (TK / 4); i++) { // 4 per row - // i*SG_SIZE index is found based on the round robin - // distribution we are using in the implementation - sum_local_rows[row + global_idx * (TK / 4)] += data[i + row * 4]; - } - sum_local_rows[row + global_idx * (TK / 4)] = reduce_over_group( - sg, sum_local_rows[row + global_idx * (TK / 4)], - sycl::plus<>()); - - // only Groups leader perform the global reduction - if (global_idy % SG_SZ == 0) { - atomic_fetch_add(v[row + global_idx * (TK / 4)], - sum_local_rows[row + global_idx * (TK / 4)]); - } - } - }); // parallel for - }).wait(); - sum_rows_ref(bufB.get_host_access(read_only), - sum_rows_v.get_host_access(read_only)); -} - -static constexpr size_t MATRIX_K = TK / 4 * 2; -static constexpr size_t MATRIX_N = TN * 4 * 2; -int8_t B[MATRIX_K][MATRIX_N]; - -int main() { - big_matrix MB((int8_t *)&B); - - size_t NDRangeK = MATRIX_K / (TK / 4); - size_t NDRangeN = (MATRIX_N / 4) / TN; - queue q; - nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); - - for (int i = 0; i < MATRIX_K; i++) { - for (int j = 0; j < MATRIX_N; j++) { - B[i][j] = i; - } - } - - matrix_sum_rows(q, MB, r); - - return 0; -} From 1e20968f7752069c7e2e370f009e8a0499e5b1f1 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 12 Oct 2023 02:54:50 +0800 Subject: [PATCH 22/23] small fix --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 1389076c52e97..327e1e326f108 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -93,6 +93,8 @@ class wi_data { decltype(auto) operator[](size_t i) { #if defined(__NVPTX__) return (jm.cuda_impl.wi_marray[i]); +#else + std::ignore = i; #endif }; }; From 1fe7fcdd13663619a0e7e3a911b1ea6bb72139d2 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 12 Oct 2023 03:06:42 +0800 Subject: [PATCH 23/23] small fix --- .../element_wise_irreg_sum_rows_impl.hpp | 105 ++++++++++++++++++ .../XMX8/element_wise_irreg_sum_rows.cpp | 26 ----- 2 files changed, 105 insertions(+), 26 deletions(-) create mode 100644 sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp delete mode 100644 sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp new file mode 100644 index 0000000000000..6a18fe3650f2c --- /dev/null +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp @@ -0,0 +1,105 @@ +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void sum_rows_ref(host_accessor B, + host_accessor sum_rows) { + int sum_rows_ref[M] = {0}; + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + sum_rows_ref[i] += B[i][j]; + } + auto diff = sum_rows[i] - sum_rows_ref[i]; + assert(std::fabs(static_cast(diff)) <= + std::numeric_limits::epsilon()); + } +} + +template +void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { + buffer bufB(B.get_data(), range<2>(M, N)); + // size of vector is known because SG size of set by the user in this case + int sum_rows[M] = {0}; + buffer sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows + q.submit([&](handler &cgh) { + auto accB = bufB.get_access(cgh); + + auto v = sum_rows_v.get_access(cgh); + + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sycl::sub_group sg = spmd_item.get_sub_group(); + + joint_matrix sub_b(sg); + + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, + N, matrix_layout::packed_b); + // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b + // (tK/4) + int32_t sum_local_rows[M] = {0}; // 8 local rows, M total + // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row + auto data = sub_b.get_wi_data(); + + // each WI calculates local sum of rows + for (int row = 0; row < TK / 4; row++) { // there are 8 rows + for (int i = 0; i < data.length() / (TK / 4); i++) { // 4 per row + // i*SG_SIZE index is found based on the round robin + // distribution we are using in the implementation + sum_local_rows[row + global_idx * (TK / 4)] += data[i + row * 4]; + } + sum_local_rows[row + global_idx * (TK / 4)] = reduce_over_group( + sg, sum_local_rows[row + global_idx * (TK / 4)], + sycl::plus<>()); + + // only Groups leader perform the global reduction + if (global_idy % SG_SZ == 0) { + atomic_fetch_add(v[row + global_idx * (TK / 4)], + sum_local_rows[row + global_idx * (TK / 4)]); + } + } + }); // parallel for + }).wait(); + sum_rows_ref(bufB.get_host_access(read_only), + sum_rows_v.get_host_access(read_only)); +} + +static constexpr size_t MATRIX_K = TK / 4 * 2; +static constexpr size_t MATRIX_N = TN * 4 * 2; +int8_t B[MATRIX_K][MATRIX_N]; + +int main() { + big_matrix MB((int8_t *)&B); + + size_t NDRangeK = MATRIX_K / (TK / 4); + size_t NDRangeN = (MATRIX_N / 4) / TN; + queue q; + nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + for (int i = 0; i < MATRIX_K; i++) { + for (int j = 0; j < MATRIX_N; j++) { + B[i][j] = i; + } + } + + matrix_sum_rows(q, MB, r); + + return 0; +} diff --git a/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp b/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp deleted file mode 100644 index 6559f4c93248d..0000000000000 --- a/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//==-------- element_wise_irreg_sum_rows.cpp - DPC++ joint_matrix----- ----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// REQUIRES: matrix-xmx8 - -// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -// RUN: %{run} %t.out - -// this code calculates the sum of rows into a global array of number of rows -// elements. First, partial reduction is computed inside each SG, then atomic -// add is used to reduce between SG leaders - -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -#define SG_SZ 8 -constexpr size_t TN = 8; - -#include "../element_wise_irreg_sum_rows_impl.hpp"