diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 5ee57a2ab3f5f..194dd075aeea0 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -51,34 +51,45 @@ struct joint_matrix< } // namespace experimental::matrix namespace detail { -using namespace experimental; -template +template struct joint_matrix_load_impl { - void load(matrix::joint_matrix &res, + void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, MT, NumRows, NumCols, Layout> &res, multi_ptr src, size_t stride); }; -template constexpr int get_layout_id(); +template +constexpr int get_layout_id(); -template <> constexpr int get_layout_id() { +template <> +constexpr int get_layout_id< + sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() { return 0; } -template <> constexpr int get_layout_id() { +template <> +constexpr int get_layout_id< + sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() { return 1; } -template +template struct joint_matrix_load_impl< - double, matrix::matrix_use::a, 8, 4, Layout, Space, - typename std::enable_if_t> { - void - load(matrix::joint_matrix &res, - multi_ptr src, size_t stride) { + double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8, 4, + Layout, Space, + typename std::enable_if_t> { + void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, + 8, 4, Layout> &res, + multi_ptr src, size_t stride) { #ifdef __NVPTX__ #ifdef __SYCL_DEVICE_ONLY__ @@ -88,14 +99,19 @@ struct joint_matrix_load_impl< } }; -template +template struct joint_matrix_load_impl< - double, matrix::matrix_use::b, 4, 8, Layout, Space, - typename std::enable_if_t> { - void - load(matrix::joint_matrix &res, - multi_ptr src, size_t stride) { + double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4, 8, + Layout, Space, + typename std::enable_if_t> { + void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, + 4, 8, Layout> &res, + multi_ptr src, size_t stride) { #ifdef __NVPTX__ #ifdef __SYCL_DEVICE_ONLY__ __dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id()); @@ -104,14 +120,21 @@ struct joint_matrix_load_impl< } }; -template +template struct joint_matrix_load_impl< - double, matrix::matrix_use::accumulator, 8, 8, Layout, Space, - typename std::enable_if_t> { - void load(matrix::joint_matrix &res, - multi_ptr src, size_t stride) { + double, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, + 8, Layout, Space, + typename std::enable_if_t> { + void + load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, + 8, Layout> &res, + multi_ptr src, size_t stride) { #ifdef __NVPTX__ #ifdef __SYCL_DEVICE_ONLY__ @@ -122,22 +145,30 @@ struct joint_matrix_load_impl< }; template + sycl::ext::oneapi::experimental::matrix::matrix_layout Layout, + access::address_space Space, typename Cond = void> struct joint_matrix_store_impl { - void store(matrix::joint_matrix &src, - multi_ptr dst, size_t stride); + void + store(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + NumRows, NumCols, Layout> &src, + multi_ptr dst, size_t stride); }; -template +template struct joint_matrix_store_impl< double, 8, 8, Layout, Space, - typename std::enable_if_t> { - void store(matrix::joint_matrix &src, - multi_ptr dst, size_t stride) { + typename std::enable_if_t> { + void + store(sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, + 8, Layout> &src, + multi_ptr dst, size_t stride) { #ifdef __NVPTX__ #ifdef __SYCL_DEVICE_ONLY__ @@ -149,60 +180,98 @@ struct joint_matrix_store_impl< }; template + sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA, + sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB, + sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC, + typename Cond = void> struct joint_matrix_mad_impl { - matrix::joint_matrix - mad(matrix::joint_matrix A, - matrix::joint_matrix B, - matrix::joint_matrix + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, LayoutC> + mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + LayoutA> + A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + LayoutB> + B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + M, N, LayoutC> C); }; -template +template constexpr int get_layout_pair_id(); template <> -constexpr int get_layout_pair_id() { +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major, + sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() { return 0; } template <> -constexpr int get_layout_pair_id() { +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major, + sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() { return 1; } template <> -constexpr int get_layout_pair_id() { +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major, + sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() { return 2; } template <> -constexpr int get_layout_pair_id() { +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major, + sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() { return 3; } -template +template struct joint_matrix_mad_impl< double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC, - typename std::enable_if_t<(LayoutA == matrix::matrix_layout::row_major || - LayoutA == matrix::matrix_layout::col_major) && - (LayoutB == matrix::matrix_layout::row_major || - LayoutB == matrix::matrix_layout::col_major) && - (LayoutC == matrix::matrix_layout::row_major || - LayoutC == matrix::matrix_layout::col_major)>> { - matrix::joint_matrix - mad(matrix::joint_matrix A, - matrix::joint_matrix B, - matrix::joint_matrix + typename std::enable_if_t< + (LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout:: + row_major || + LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout:: + col_major) && + (LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout:: + row_major || + LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout:: + col_major) && + (LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout:: + row_major || + LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout:: + col_major)>> { + sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + 8, 8, LayoutC> + mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8, 4, + LayoutA> + A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4, 8, + LayoutB> + B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, + 8, LayoutC> C) { - matrix::joint_matrix + sycl::ext::oneapi::experimental::matrix::joint_matrix< + double, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, 8, + LayoutC> D; #ifdef __NVPTX__ @@ -225,8 +294,9 @@ template &res, multi_ptr src, size_t stride) { - detail::joint_matrix_load_impl{}.load( - res, src, stride); + sycl::ext::oneapi::detail::joint_matrix_load_impl{} + .load(res, src, stride); } template &src, multi_ptr dst, size_t stride) { - detail::joint_matrix_store_impl{}.store( - src, dst, stride); + sycl::ext::oneapi::detail::joint_matrix_store_impl{} + .store(src, dst, stride); } template A, joint_matrix B, joint_matrix C) { - return detail::joint_matrix_mad_impl{} + return sycl::ext::oneapi::detail::joint_matrix_mad_impl< + T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{} .mad(A, B, C); } diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp index 60fbdc458812e..408899e0897ea 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp @@ -50,15 +50,15 @@ int main() { joint_matrix sub_b; - //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}} + //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %_arg_, i32 8) #{{.*}} joint_matrix_load(sg, sub_c, accC.get_pointer(), N); - //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 4) #{{.*}} + //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %_arg_4, i32 4) #{{.*}} joint_matrix_load(sg, sub_a, accA.get_pointer(), K); - //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 8) #{{.*}} + //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %_arg_9, i32 8) #{{.*}} joint_matrix_load(sg, sub_b, accB.get_pointer(), N); - //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %11, double %12, double %9, double %10) #{{.*}} + //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %3, double %4, double %1, double %2) #{{.*}} sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}} + //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %_arg_14, double %6, double %7, i32 8) #{{.*}} joint_matrix_store(sg, sub_c, accD.get_pointer(), N); }); }); @@ -84,15 +84,15 @@ int main() { joint_matrix sub_b; - //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}} + //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %_arg_, i32 8) #{{.*}} joint_matrix_load(sg, sub_c, accC.get_pointer(), M); - //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 8) #{{.*}} + //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %_arg_4, i32 8) #{{.*}} joint_matrix_load(sg, sub_a, accA.get_pointer(), M); - //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 4) #{{.*}} + //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %_arg_9, i32 4) #{{.*}} joint_matrix_load(sg, sub_b, accB.get_pointer(), K); - //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %11, double %12, double %9, double %10) #{{.*}} + //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %3, double %4, double %1, double %2) #{{.*}} sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}} + //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %_arg_14, double %6, double %7, i32 8) #{{.*}} joint_matrix_store(sg, sub_c, accD.get_pointer(), M); }); });