Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 80 additions & 30 deletions SYCL/Matrix/joint_matrix_tensorcore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ uint16_t make_bf16(float x) {
return (uint16_t)*res;
}

template <typename T1, typename T2, size_t Big_N, size_t Big_K>
template <size_t Big_N, size_t Big_K, typename T1, typename T2>
T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
T2 res = C[m * Big_N + n];

Expand All @@ -80,7 +80,8 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
}

template <typename T1, typename T2, size_t Sub_Tiles_M, size_t Sub_Tiles_K,
size_t Sub_Tiles_N, size_t M, size_t K, size_t N, typename T3 = T1>
size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
typename T3 = std::remove_const_t<T1>>
void test(queue &q) {

constexpr auto Big_M =
Expand All @@ -93,25 +94,26 @@ void test(queue &q) {
Sub_Tiles_K *
K; // total number of K dimension matrix elements for the "Big matrix".

T1 A[Big_M * Big_K];
T1 B[Big_K * Big_N];
T2 C[Big_M * Big_N];
T2 D[Big_M * Big_N];
std::remove_const_t<T1> A[Big_M * Big_K];
std::remove_const_t<T1> B[Big_K * Big_N];
std::remove_const_t<T2> C[Big_M * Big_N];
std::remove_const_t<T2> D[Big_M * Big_N];

for (int i = 0; i < Big_M * Big_N; i++) {
C[i] = 1;
D[i] = 0;
}

if constexpr (std::is_same<T1, uint16_t>::value) {
if constexpr (std::is_same<std::remove_const_t<T1>, uint16_t>::value) {
for (int i = 0; i < Big_M * Big_K; i++) {
A[i] = make_bf16(0.1f * (i % 10));
}

for (int i = 0; i < Big_K * Big_N; i++) {
B[i] = make_bf16(0.1f * (i % 10));
}
} else if constexpr (!std::is_same<T1, bfloat16>::value) {
} else if constexpr (!std::is_same<std::remove_const_t<T1>,
bfloat16>::value) {
for (int i = 0; i < Big_M * Big_K; i++) {
A[i] = i % 100;
}
Expand All @@ -121,41 +123,43 @@ void test(queue &q) {
}
}
{
buffer<T1, 1> bufA(A, range<1>(Big_M * Big_K));
buffer<T1, 1> bufB(B, range<1>(Big_K * Big_N));
buffer<T2, 1> bufC(C, range<1>(Big_M * Big_N));
buffer<T2, 1> bufD(D, range<1>(Big_M * Big_N));
if constexpr (std::is_same<std::remove_const_t<T1>, bfloat16>::value) {

// currently bfloat16 has to be initialized on device
if constexpr (std::is_same<T1, bfloat16>::value) {
buffer<bfloat16, 1> bufA(A, range<1>(Big_M * Big_K));
buffer<bfloat16, 1> bufB(B, range<1>(Big_K * Big_N));
q.submit([&](handler &cgh) {
accessor<T1, 1, access::mode::read_write, target::device> accA(bufA,
cgh);
accessor<bfloat16, 1, access::mode::write, target::device> accA(bufA,
cgh);

cgh.parallel_for<KernelName<bfloat16, class copyA, M, K, N>>(
cgh.parallel_for<KernelName<T1, class copyA, M, K, N>>(
range<1>(Big_M * Big_K), [=](item<1> item) {
auto i = item.get_linear_id();
accA[i] = 0.1f * (i % 10);
});
});

q.submit([&](handler &cgh) {
accessor<T1, 1, access::mode::read_write, target::device> accB(bufB,
cgh);
accessor<bfloat16, 1, access::mode::write, target::device> accB(bufB,
cgh);

cgh.parallel_for<KernelName<bfloat16, class copyB, M, K, N>>(
cgh.parallel_for<KernelName<T1, class copyB, M, K, N>>(
range<1>(Big_K * Big_N), [=](item<1> item) {
auto i = item.get_linear_id();
accB[i] = 0.1f * (i % 10);
});
});
}

buffer<T1, 1> bufA(A, range<1>(Big_M * Big_K));
buffer<T1, 1> bufB(B, range<1>(Big_K * Big_N));
buffer<T2, 1> bufC(C, range<1>(Big_M * Big_N));
buffer<std::remove_const_t<T2>, 1> bufD(D, range<1>(Big_M * Big_N));

q.submit([&](handler &cgh) {
accessor<T1, 1, access::mode::read_write, target::device> accA(bufA, cgh);
accessor<T1, 1, access::mode::read_write, target::device> accB(bufB, cgh);
accessor<T2, 1, access::mode::read_write, target::device> accC(bufC, cgh);
accessor<T2, 1, access::mode::read_write, target::device> accD(bufD, cgh);
accessor<T1, 1, access::mode::read, target::device> accA(bufA, cgh);
accessor<T1, 1, access::mode::read, target::device> accB(bufB, cgh);
accessor<T2, 1, access::mode::read, target::device> accC(bufC, cgh);
accessor<std::remove_const_t<T2>, 1, access::mode::write, target::device>
accD(bufD, cgh);

range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP};
range<2> GlobalRange = {Sub_Tiles_M,
Expand All @@ -177,7 +181,7 @@ void test(queue &q) {
joint_matrix<T3, matrix_use::b, K, N, matrix_layout::row_major>
sub_b;

joint_matrix<T2, matrix_use::accumulator, M, N,
joint_matrix<std::remove_const_t<T2>, matrix_use::accumulator, M, N,
matrix_layout::row_major>
sub_c;

Expand Down Expand Up @@ -216,14 +220,14 @@ void test(queue &q) {

for (int m = 0; m < Big_M; m++) {
for (int n = 0; n < Big_N; n++) {
if constexpr (std::is_same<T1, bfloat16>::value) {
auto res_device = matrix_ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C);
if constexpr (std::is_same<std::remove_const_t<T1>, bfloat16>::value) {
auto res_device = matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C);
assert(fabs(2 * (D[m * Big_N + n] - res_device)) /
(D[m * Big_N + n] + res_device) <
bf16_eps * 2);
} else {
assert((D[m * Big_N + n] ==
matrix_ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C)));
assert(
(D[m * Big_N + n] == matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C)));
}
}
}
Expand All @@ -241,36 +245,82 @@ int main() {
test<half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
test<half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);

test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16,
16>(Q);
test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16,
32>(Q);
test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16,
8>(Q);

// A/B/Accumulator half
test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);

test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16,
16>(Q);
test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16,
32>(Q);
test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16,
8>(Q);
}
if (computeCapability >= 7.2) {
test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);

test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
16, 16>(Q);
test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
16, 32>(Q);
test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
16, 8>(Q);

test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(
Q);
test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);

test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
16, 16, 16>(Q);
test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
16, 32>(Q);
test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
32, 16, 8>(Q);
}
if (computeCapability >= 8.0) {
test<double, double, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 4, 8>(Q);
test<const double, const double, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
4, 8>(Q);

// A/B bfloat16 using storage type
test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);

test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
16, 16>(Q);
test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
16, 32>(Q);
test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
16, 8>(Q);

test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);

test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
16, 16>(Q);
test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
16, 32>(Q);
test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
16, 8>(Q);

// A/B tf32
test<float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8, 16,
precision::tf32>(Q);
test<const float, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8,
16, precision::tf32>(Q);
}
return 0;
};