diff --git a/SYCL/Matrix/joint_matrix_tensorcore.cpp b/SYCL/Matrix/joint_matrix_tensorcore.cpp
index f6985edcb6..8582ef955f 100644
--- a/SYCL/Matrix/joint_matrix_tensorcore.cpp
+++ b/SYCL/Matrix/joint_matrix_tensorcore.cpp
@@ -58,7 +58,7 @@ uint16_t make_bf16(float x) {
   return (uint16_t)*res;
 }
 
-template <typename T1, typename T2, size_t Big_N, size_t Big_K>
+template <size_t Big_N, size_t Big_K, typename T1, typename T2>
 T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
 
   T2 res = C[m * Big_N + n];
@@ -80,7 +80,8 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
 }
 
 template <typename T1, typename T2, size_t Sub_Tiles_M, size_t Sub_Tiles_K,
-          size_t Sub_Tiles_N, size_t M, size_t K, size_t N>
+          size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
+          typename T3 = std::remove_const_t<T1>>
 void test(queue &q) {
 
   constexpr auto Big_M =
@@ -93,17 +94,17 @@ void test(queue &q) {
       Sub_Tiles_K *
       K; // total number of K dimension matrix elements for the "Big matrix".
 
-  T1 A[Big_M * Big_K];
-  T1 B[Big_K * Big_N];
-  T2 C[Big_M * Big_N];
-  T2 D[Big_M * Big_N];
+  std::remove_const_t<T1> A[Big_M * Big_K];
+  std::remove_const_t<T1> B[Big_K * Big_N];
+  std::remove_const_t<T2> C[Big_M * Big_N];
+  std::remove_const_t<T2> D[Big_M * Big_N];
 
   for (int i = 0; i < Big_M * Big_N; i++) {
     C[i] = 1;
    D[i] = 0;
   }
 
-  if constexpr (std::is_same<T1, uint16_t>::value) {
+  if constexpr (std::is_same<std::remove_const_t<T1>, uint16_t>::value) {
     for (int i = 0; i < Big_M * Big_K; i++) {
       A[i] = make_bf16(0.1f * (i % 10));
     }
@@ -111,7 +112,8 @@ void test(queue &q) {
     for (int i = 0; i < Big_K * Big_N; i++) {
       B[i] = make_bf16(0.1f * (i % 10));
     }
-  } else if constexpr (!std::is_same<T1, bfloat16>::value) {
+  } else if constexpr (!std::is_same<std::remove_const_t<T1>,
+                                     bfloat16>::value) {
     for (int i = 0; i < Big_M * Big_K; i++) {
       A[i] = i % 100;
     }
@@ -121,29 +123,25 @@ void test(queue &q) {
       B[i] = i % 100;
     }
   }
   {
-    buffer<T1, 1> bufA(A, range<1>(Big_M * Big_K));
-    buffer<T1, 1> bufB(B, range<1>(Big_K * Big_N));
-    buffer<T2, 1> bufC(C, range<1>(Big_M * Big_N));
-    buffer<T2, 1> bufD(D, range<1>(Big_M * Big_N));
+    if constexpr (std::is_same<std::remove_const_t<T1>, bfloat16>::value) {
 
-    // currently bfloat16 has to be initialized on device
-    if constexpr (std::is_same<T1, bfloat16>::value) {
+      buffer<bfloat16, 1> bufA(A, range<1>(Big_M * Big_K));
+      buffer<bfloat16, 1> bufB(B, range<1>(Big_K * Big_N));
       q.submit([&](handler &cgh) {
-        accessor<T1, 1, access::mode::read_write, target::device> accA(bufA,
-                                                                       cgh);
+        accessor<bfloat16, 1, access::mode::write, target::device> accA(bufA,
+                                                                        cgh);
 
-        cgh.parallel_for<KernelName<T1, class copyA, M, K, N>>(
+        cgh.parallel_for<KernelName<T3, class copyA, M, K, N>>(
             range<1>(Big_M * Big_K), [=](item<1> item) {
               auto i = item.get_linear_id();
               accA[i] = 0.1f * (i % 10);
             });
       });
-
       q.submit([&](handler &cgh) {
-        accessor<T1, 1, access::mode::read_write, target::device> accB(bufB,
-                                                                       cgh);
+        accessor<bfloat16, 1, access::mode::write, target::device> accB(bufB,
+                                                                        cgh);
 
-        cgh.parallel_for<KernelName<T1, class copyB, M, K, N>>(
+        cgh.parallel_for<KernelName<T3, class copyB, M, K, N>>(
             range<1>(Big_K * Big_N), [=](item<1> item) {
               auto i = item.get_linear_id();
               accB[i] = 0.1f * (i % 10);
@@ -151,11 +149,17 @@ void test(queue &q) {
       });
     }
 
+    buffer<T1, 1> bufA(A, range<1>(Big_M * Big_K));
+    buffer<T1, 1> bufB(B, range<1>(Big_K * Big_N));
+    buffer<T2, 1> bufC(C, range<1>(Big_M * Big_N));
+    buffer<std::remove_const_t<T2>, 1> bufD(D, range<1>(Big_M * Big_N));
+
     q.submit([&](handler &cgh) {
-      accessor<T1, 1, access::mode::read_write, target::device> accA(bufA, cgh);
-      accessor<T1, 1, access::mode::read_write, target::device> accB(bufB, cgh);
-      accessor<T2, 1, access::mode::read_write, target::device> accC(bufC, cgh);
-      accessor<T2, 1, access::mode::read_write, target::device> accD(bufD, cgh);
+      accessor<T1, 1, access::mode::read, target::device> accA(bufA, cgh);
+      accessor<T1, 1, access::mode::read, target::device> accB(bufB, cgh);
+      accessor<T2, 1, access::mode::read, target::device> accC(bufC, cgh);
+      accessor<std::remove_const_t<T2>, 1, access::mode::write, target::device>
+          accD(bufD, cgh);
 
       range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP};
       range<2> GlobalRange = {Sub_Tiles_M,
@@ -177,7 +181,7 @@ void test(queue &q) {
 
           joint_matrix<T1, matrix_use::b, K, N, matrix_layout::row_major>
              sub_b;
-          joint_matrix<T2,
+          joint_matrix<std::remove_const_t<T2>,
                        matrix_use::accumulator, M, N, matrix_layout::row_major>
               sub_c;
 
@@ -216,14 +220,14 @@ void test(queue &q) {
 
   for (int m = 0; m < Big_M; m++) {
     for (int n = 0; n < Big_N; n++) {
-      if constexpr (std::is_same<T1, bfloat16>::value) {
-        auto res_device = matrix_ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C);
+      if constexpr (std::is_same<std::remove_const_t<T1>, bfloat16>::value) {
+        auto res_device = matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C);
         assert(fabs(2 * (D[m * Big_N + n] - res_device)) /
                    (D[m * Big_N + n] + res_device) <
                bf16_eps * 2);
       } else {
-        assert((D[m * Big_N + n] ==
-                matrix_ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C)));
+        assert(
+            (D[m * Big_N + n] == matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C)));
       }
     }
   }
@@ -241,36 +245,82 @@ int main() {
     test(Q);
     test(Q);
 
+    test(Q);
+    test(Q);
+    test(Q);
+
     // A/B/Accumulator half
     test(Q);
     test(Q);
     test(Q);
+
+    test(Q);
+    test(Q);
+    test(Q);
   }
 
   if (computeCapability >= 7.2) {
     test(Q);
     test(Q);
     test(Q);
+    test(Q);
+    test(Q);
+    test(Q);
+    test(
+        Q);
 
     test(Q);
     test(Q);
+
+    test(Q);
+    test(Q);
+    test(Q);
   }
 
   if (computeCapability >= 8.0) {
     test(Q);
+    test(Q);
 
     // A/B bfloat16 using storage type
     test(Q);
     test(Q);
     test(Q);
+    test(Q);
+    test(Q);
+    test(Q);
+    test(Q);
 
     test(Q);
     test(Q);
+    test(Q);
+    test(Q);
+    test(Q);
+
     // A/B tf32
     test(Q);
+    test(Q);
   }
 
   return 0;
 };
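
Note: the core trick this patch applies is templating the test on possibly const-qualified element types while using std::remove_const_t to keep writable host storage (the A/B/C/D arrays, the D buffer/accessor, and the accumulator joint_matrix). The following is a minimal standalone sketch of that pattern, not part of the patch; the names my_test, T, and N are illustrative assumptions.

#include <cstddef>
#include <type_traits>

// T may be const-qualified by the caller; T3 strips the const so the
// function can own writable storage while the const-qualified T is still
// exercised at the type level (mirroring test<const half, const float, ...>).
template <typename T, std::size_t N, typename T3 = std::remove_const_t<T>>
void my_test() {
  T3 data[N]; // writable host array, even when T is const
  for (std::size_t i = 0; i < N; i++)
    data[i] = static_cast<T3>(i);

  // A pointer of the original (possibly const) type can still view the data,
  // much as the patch binds const-element buffers/accessors to mutable arrays.
  T *view = data;
  (void)view;
}

int main() {
  my_test<int, 4>();       // non-const instantiation
  my_test<const int, 4>(); // const-qualified instantiation still compiles
  return 0;
}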