Skip to content
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 154 additions & 15 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,34 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)

// single-bit
template <matrix_layout Layout>
struct joint_matrix<
uint32_t, matrix_use::a, 8, 4, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
joint_matrix() {
static_assert((Layout == matrix_layout::row_major),
"For the matrix_use::a case, matrix_layout::row_major must "
"be used for Bitwise MAD");
};
int32_t data[1];
};

// Fragment of B for Bitwise MAD (m8n8k128): each uint32_t packs 32
// single-bit elements, so the 4x8 uint32_t tile presumably represents a
// 128x8 bit-matrix -- TODO confirm against the PTX bmma intrinsics.
template <matrix_layout Layout>
struct joint_matrix<
    uint32_t, matrix_use::b, 4, 8, Layout, sycl::sub_group,
    std::enable_if_t<Layout == matrix_layout::row_major ||
                     Layout == matrix_layout::col_major>> {
  joint_matrix() {
    // The single-bit MMA path only accepts a col-major B fragment; reject
    // row-major instantiations at compile time.
    static_assert((Layout == matrix_layout::col_major),
                  "For the matrix_use::b case, matrix_layout::col_major must "
                  "be used for Bitwise MAD");
  }
  int32_t data[1];
};
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 8, int32_t, 2)

#undef __SYCL_JOINT_MATRIX_OVERLOAD
} // namespace experimental::matrix

Expand Down Expand Up @@ -235,6 +263,28 @@ struct joint_matrix_load_impl<
get_layout_id<Layout>());
}

} else if constexpr (std::is_same<T, double>::value) {
if constexpr (Use ==
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
__dmma_m8n8k4_ld_a(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
matrix_use::b) {
__dmma_m8n8k4_ld_b(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
matrix_use::accumulator) {
__dmma_m8n8k4_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (NumRows == 8 && NumCols == 4) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
__bmma_m8n8k128_ld_a_b1(res.data, tileptr, stride * 32,
get_layout_id<Layout>());
} else if constexpr (NumRows == 4 && NumCols == 8) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
__bmma_m8n8k128_ld_b_b1(res.data, tileptr, stride * 32,
get_layout_id<Layout>());
} else if constexpr (std::is_same<T, int32_t>::value) {
if constexpr (NumRows == 16 && NumCols == 16) {
__imma_m16n16k16_ld_c(res.data, src.get(), stride,
Expand All @@ -245,6 +295,9 @@ struct joint_matrix_load_impl<
} else if constexpr (NumRows == 32 && NumCols == 8) {
__imma_m32n8k16_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 8 && NumCols == 8) {
__bmma_m8n8k128_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, float>::value) {
if constexpr (NumRows == 16 && NumCols == 16) {
Expand All @@ -257,20 +310,6 @@ struct joint_matrix_load_impl<
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, double>::value) {
if constexpr (Use ==
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
__dmma_m8n8k4_ld_a(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
matrix_use::b) {
__dmma_m8n8k4_ld_b(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
matrix_use::accumulator) {
__dmma_m8n8k4_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
}
}
}
};
Expand Down Expand Up @@ -339,6 +378,9 @@ struct joint_matrix_store_impl<
} else if constexpr (std::is_same<T, double>::value) {
__dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same<T, int32_t>::value) {
__bmma_m8n8k128_st_c_i32(dst.get(), src.data, stride,
get_layout_id<Layout>());
}
}
};
Expand Down Expand Up @@ -366,6 +408,31 @@ struct joint_matrix_mad_impl {
C);
};

// Primary template for the Bitwise MAD (bmad) implementation helper.
// Declared only; the actual implementation lives in the partial
// specialization that constrains LayoutC to row_major/col_major.  A is
// fixed to row-major and B to col-major, matching the single-bit
// joint_matrix specializations above.  `Cond` is the enable_if hook used
// by the specialization.
template <std::size_t M, std::size_t K, std::size_t N,
          sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
          class BinaryOperation, typename Cond = void>
struct joint_matrix_bmad_impl {
  sycl::ext::oneapi::experimental::matrix::joint_matrix<
      int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
      M, N, LayoutC, sycl::sub_group>
  bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M,
           K, sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
           sycl::sub_group>
           A,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K,
           N, sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
           sycl::sub_group>
           B,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           int32_t,
           sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
           N, LayoutC, sycl::sub_group>
           C,
       BinaryOperation Op);
};

template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();
Expand Down Expand Up @@ -495,14 +562,59 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
} else if constexpr (std::is_same<T1, double>::value) {
} else if constexpr (M == 8 && N == 8 && K == 4) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change related to the bmad addition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No this is a superficial/non-important change that I made just for better consistency of the if constexpr statements in this function.

__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
return D;
}
};

// Specialization implementing Bitwise MAD: D = popc(A op B) + C, where `op`
// is a bitwise AND or XOR applied across the bit (K) dimension -- this maps
// onto the m8n8k128 bmma popc intrinsics.  A must be row-major and B
// col-major (enforced by the single-bit joint_matrix specializations);
// LayoutC may be either layout.
template <std::size_t M, std::size_t K, std::size_t N,
          sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
          class BinaryOperation>
struct joint_matrix_bmad_impl<
    M, K, N, LayoutC, BinaryOperation,
    typename std::enable_if_t<(
        LayoutC ==
            sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major ||
        LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
                       col_major)>> {
  sycl::ext::oneapi::experimental::matrix::joint_matrix<
      int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
      M, N, LayoutC, sycl::sub_group>
  bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M,
           K, sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
           sycl::sub_group>
           A,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K,
           N, sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
           sycl::sub_group>
           B,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           int32_t,
           sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
           N, LayoutC, sycl::sub_group>
           C,
       BinaryOperation Op) {
    // Reject unsupported operations at compile time: previously an
    // unsupported BinaryOperation silently returned an uninitialized
    // accumulator fragment.
    static_assert(
        std::is_same<BinaryOperation, sycl::bit_and<uint32_t>>::value ||
            std::is_same<BinaryOperation, sycl::bit_xor<uint32_t>>::value,
        "Bitwise MAD only supports sycl::bit_and<uint32_t> and "
        "sycl::bit_xor<uint32_t>");
    (void)Op; // Op selects the intrinsic purely through its type.
    sycl::ext::oneapi::experimental::matrix::joint_matrix<
        int32_t,
        sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N,
        LayoutC, sycl::sub_group>
        D;
    if constexpr (std::is_same<BinaryOperation,
                               sycl::bit_and<uint32_t>>::value) {
      __bmma_m8n8k128_mma_and_popc_b1(D.data, A.data, B.data, C.data, 1);
    } else if constexpr (std::is_same<BinaryOperation,
                                      sycl::bit_xor<uint32_t>>::value) {
      __bmma_m8n8k128_mma_xor_popc_b1(D.data, A.data, B.data, C.data, 1);
    }
    return D;
  }
};

} // namespace detail

namespace experimental::matrix {
Expand Down Expand Up @@ -573,6 +685,33 @@ joint_matrix_mad(
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

/// Bitwise multiply-and-add (bmad) on packed single-bit matrices:
/// returns the accumulator produced from fragments A (row-major) and B
/// (col-major) combined with C via the bitwise operation selected by Op.
/// \param sg the subgroup cooperating on the operation (consumed only to
///        match the joint_matrix API shape; unused on the device path).
/// \param A  packed single-bit A fragment (uint32_t storage, row-major).
/// \param B  packed single-bit B fragment (uint32_t storage, col-major).
/// \param C  int32_t accumulator fragment.
/// \param Op the bitwise "product": the device implementation dispatches on
///        sycl::bit_and<uint32_t> / sycl::bit_xor<uint32_t>.
/// \throws runtime_error with PI_INVALID_DEVICE when called on any
///        non-CUDA compilation path (host or other devices).
template <typename Group, std::size_t M, std::size_t K, std::size_t N,
          matrix_layout LayoutC, class BinaryOperation>
joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_bmad(
    Group sg,
    joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major, Group>
        A,
    joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major, Group>
        B,
    joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group> C,
    BinaryOperation Op) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  return sycl::ext::oneapi::detail::joint_matrix_bmad_impl<M, K, N, LayoutC,
                                                           BinaryOperation>{}
      .bmad(A, B, C, Op);
#else
  // Host/other-device fallback: silence unused-parameter warnings, then
  // report that this extension point is CUDA-only.
  (void)sg;
  (void)A;
  (void)B;
  (void)C;
  (void)Op;
  throw runtime_error("joint_matrix_bmad is "
                      "only supported by CUDA devices",
                      PI_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_75 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation.
constexpr int M = 8; // number of rows of accumulator,
                     // number of rows of a.
constexpr int N = 8; // number of cols of accumulator,
                     // number of cols of b.
constexpr int K = 128; // number of cols of a / number of rows of b, counted
                       // in single-bit elements (32 bits packed per
                       // uint32_t).  NOTE(review): the single-bit
                       // joint_matrix shapes use uint32_t-element counts
                       // (8x4 / 4x8) -- confirm whether K here should be 4
                       // instead of 128, as flagged in review.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You forgot to make the change here.
K should be 4 here.
Can you please add a comment where you are making these changes?
Basically, saying that the underlying intrinsics are expecting a shape of K equals to number of total bits, not number of elements.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I forgot that. Fixed now, and I also updated the test so it will work with the legacy pass-manager.

I've added a more detailed comment describing Bitwise Dot Product and how this dictates the relation between the number of Array elements used for A/B arrays and the number of single-bit matrix elements that the A/B arrays represent. I've also correspondingly updated the test in intel/llvm-test-suite and the tensor cores matrix extension PR #4695.


// Host-side storage.  A and B are bit-matrices: M*K and K*N single-bit
// elements packed 32 per uint32_t, hence the division by 32.  C and D are
// ordinary int32 accumulators.
uint32_t A[M * K / 32];
uint32_t B[K * N / 32];
int32_t C[M * N];
int32_t D[M * N];

// Drives one subgroup (a full 32-lane warp) through load -> bmad (xor, then
// and) -> store, and FileCheck-verifies that the expected m8n8k128 NVVM
// intrinsics are emitted.  The CHECK lines below are FileCheck directives:
// their exact text is load-bearing and must not be edited casually.
int main() {

  buffer<uint32_t, 1> bufA(A, range<1>(M * K / 32));
  buffer<uint32_t, 1> bufB(B, range<1>(K * N / 32));
  buffer<int32_t, 1> bufC(C, range<1>(M * N));
  buffer<int32_t, 1> bufD(D, range<1>(M * N));

  queue q;

  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    // A single work-group of 32 work-items == one subgroup/warp; the
    // reqd_work_group_size attribute pins this so the matrix ops are valid.
    cgh.parallel_for<class row_col>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          // Accumulator fragment (int32, row-major).
          joint_matrix<int32_t, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          // Single-bit A fragment: row-major is mandatory for matrix_use::a.
          joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major>
              sub_a;

          // Single-bit B fragment: col-major is mandatory for matrix_use::b.
          joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major>
              sub_b;

          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
          //CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.a.row.stride.b1.p0i32(i32* %call.ascast.i.i63.i, i32 128) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.b.col.stride.b1.p0i32(i32* %call.ascast.i.i.i, i32 128) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1(i32 %3, i32 %4, i32 %1, i32 %2) #{{.*}}
          sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
                                    sycl::bit_xor<uint32_t>());
          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1(i32 %3, i32 %4, i32 %6, i32 %7) #{{.*}}
          sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
                                    sycl::bit_and<uint32_t>());
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k128.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %9, i32 %10, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

  return 0;
};