diff --git a/SYCL/Matrix/XMX8/get_coord_bf16_gemm.cpp b/SYCL/Matrix/XMX8/get_coord_bf16_gemm.cpp new file mode 100644 index 0000000000..2890298389 --- /dev/null +++ b/SYCL/Matrix/XMX8/get_coord_bf16_gemm.cpp @@ -0,0 +1,26 @@ +//==----------- get_coord_bf16_gemm.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// XFAIL:* + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::bfloat16; + +#define SG_SZ 8 + +#include "../get_coord_bf16_gemm_impl.hpp" diff --git a/SYCL/Matrix/XMX8/get_coord_bf16_matA.cpp b/SYCL/Matrix/XMX8/get_coord_bf16_matA.cpp new file mode 100644 index 0000000000..e80464c1ae --- /dev/null +++ b/SYCL/Matrix/XMX8/get_coord_bf16_matA.cpp @@ -0,0 +1,26 @@ +//==----------- get_coord_bf16_matA.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// XFAIL:* + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::bfloat16; + +#define SG_SZ 8 + +#include "../get_coord_bf16_matA_impl.hpp" diff --git a/SYCL/Matrix/XMX8/get_coord_bf16_matB.cpp b/SYCL/Matrix/XMX8/get_coord_bf16_matB.cpp new file mode 100644 index 0000000000..6c7bbf7da1 --- /dev/null +++ b/SYCL/Matrix/XMX8/get_coord_bf16_matB.cpp @@ -0,0 +1,26 @@ +//==----------- get_coord_bf16_matB.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// XFAIL:* + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::bfloat16; + +#define SG_SZ 8 + +#include "../get_coord_bf16_matB_impl.hpp" diff --git a/SYCL/Matrix/get_coord_bf16_gemm.cpp b/SYCL/Matrix/get_coord_bf16_gemm.cpp new file mode 100644 index 0000000000..61fac9cc98 --- /dev/null +++ b/SYCL/Matrix/get_coord_bf16_gemm.cpp @@ -0,0 +1,27 @@ +//==----------- get_coord_bf16_gemm.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// XFAIL:* + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::bfloat16; + +#define SG_SZ 16 + +#include "get_coord_bf16_gemm_impl.hpp" diff --git a/SYCL/Matrix/get_coord_bf16_gemm_impl.hpp b/SYCL/Matrix/get_coord_bf16_gemm_impl.hpp new file mode 100644 index 0000000000..8b36f20c27 --- /dev/null +++ b/SYCL/Matrix/get_coord_bf16_gemm_impl.hpp @@ -0,0 +1,208 @@ +#define TM 8 +#define TN SG_SZ +#define TK 16 + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; + +#define BF16_EPSILON 0.00781250 + +template struct big_matrix { +private: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +// clang-format off +/* +Here's how the data is distributed +W0 --> 0 1 2 3 4 5 6 7 +wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 1] wi [0,15] --> i=0, [0, 15] + i=1, [1, 0] i=1, [1, 1] i=1, [1, 15] + i=2, [2, 0] i=2, [2, 1] ... + ... .... + i=7, [7, 0] i=7, [7, 1] +*/ +// clang-format on +std::tuple get_coord_ref(int i, int wi_number) { + return std::make_tuple(i, wi_number); +} + +float sum_rows[MATRIX_M] = {0}; + +template +void matrix_multiply(big_matrix &C, big_matrix &A, + big_matrix &B) { + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC((float *)C.get_data(), range<2>(M, N)); + + buffer sum_rows_v(sum_rows, M); // there are total of M rows + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + auto v = sum_rows_v.get_access(cgh); + auto os = sycl::stream(100000, 6144, cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; + + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, layout::row_major); + for (int k = 0; k < K / TK; k += 1) { // + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K); + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 2) * (N * 2) + + sg_starty / SG_SZ * TN * 2, + N * 2); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, layout::row_major); + + float sum_local_rows[M] = {0}; // 8 local rows, M total + auto data = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + + // Keep track of rows handled in this WI + int32_t handled_rows[M] = {-1}; + size_t + global_index; // Index into the result array that holds the sums. + + for (int i = 0; i < data.length(); ++i) { + auto dataItem = data[i]; + auto [row, col] = dataItem.get_coord(); + // get_coord_ref(i, spmd_item.get_local_id(1)); + global_index = row + global_idx * TM; + + sum_local_rows[global_index] += data[i]; + + handled_rows[global_index] = 1; + } + + for (int j = 0; j < M; j++) { + if (handled_rows[j] == 1) { + global_index = j; + sum_local_rows[global_index] = reduce_over_group( + sg, sum_local_rows[global_index], sycl::plus<>()); + // only Groups leader perform the global reduction + if (global_idy % SG_SZ == 0) { + sycl::atomic_ref + aref(v[global_index]); + aref.fetch_add(sum_local_rows[global_index]); + } + } + } + }); // parallel for + }).wait(); +} + +bfloat16 A[MATRIX_M][MATRIX_K]; +bfloat16 B[MATRIX_K / 2][MATRIX_N * 2]; +float C[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +float make_fp32(bfloat16 x) { + unsigned int y = *((int *)&x); + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; +} + +void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, + int K) { + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + // Because B was assumed VNNIed + bfloat16 *va = (bfloat16 *)(A_mem + m * K + k); + bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n); + float acc = *((float *)(C_mem + m * N + n)); + for (int i = 0; i < 2; i++) { + acc += (make_fp32(va[i]) * make_fp32(vb[i])); + } + *((float *)(C_mem + m * N + n)) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = bfloat16(1.0f * (i + j)); + } + } + for (int i = 0; i < MATRIX_K / 2; i++) { + for (int j = 0; j < MATRIX_N * 2; j++) { + B[i][j] = bfloat16(2.0f * i + 3.0f * j); + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1.0; + D[i][j] = 1.0; + } + } + + big_matrix MC((float *)&C); + big_matrix MD((float *)&D); + big_matrix MA((bfloat16 *)&A); + big_matrix MB((bfloat16 *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 2); + + bool res = true; + float sum_rows_ref[MATRIX_M] = {0}; + + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + // std::cout << C[i][j] << " "; + if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON) + res = false; + sum_rows_ref[i] += C[i][j]; + } + if ((fabs(sum_rows_ref[i]) - fabs(sum_rows[i])) > BF16_EPSILON) + res = false; + // std::cout << "\n"; + } + std::cout << (res ? "passed" : "failed") << std::endl; + return !res; +} diff --git a/SYCL/Matrix/get_coord_bf16_matA.cpp b/SYCL/Matrix/get_coord_bf16_matA.cpp new file mode 100644 index 0000000000..e5570c9abb --- /dev/null +++ b/SYCL/Matrix/get_coord_bf16_matA.cpp @@ -0,0 +1,27 @@ +//==----------- get_coord_bf16_matA.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// XFAIL:* + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::bfloat16; + +#define SG_SZ 16 + +#include "get_coord_bf16_matA_impl.hpp" diff --git a/SYCL/Matrix/get_coord_bf16_matA_impl.hpp b/SYCL/Matrix/get_coord_bf16_matA_impl.hpp new file mode 100644 index 0000000000..3d10598fa2 --- /dev/null +++ b/SYCL/Matrix/get_coord_bf16_matA_impl.hpp @@ -0,0 +1,187 @@ +#define TM 8 +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void sum_rows_ref(host_accessor A, + host_accessor sum_rows) { + int sum_rows_ref[M] = {0}; + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < K; j++) { + sum_rows_ref[i] += A[i][j]; + } + auto diff = sum_rows[i] - sum_rows_ref[i]; + assert(std::fabs(static_cast(diff)) <= + std::numeric_limits::epsilon()); + } +} + +// clang-format off +/* +Here's how the data is distributed among work items +0 0 0 0 +/ +/ +1 1 1 1 +/ +/ +2 2 2 2 +/ +/ +3 3 3 3 +W0 --> 0 0 1 1 2 2 3 3 .... 7 7 +wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 2] wi [0,15] --> i=0, [0, 30] + i=1, [0, 1] i=1, [0, 3] i=1, [0, 31] + i=2, [1, 0] i=2, [1, 2] i=2, [1, 30] + i=3, [1, 1] i=3, [1, 3] i=3, [1, 31] + i=4, [2, 0] i=4, [2, 2] ... + i=5, [2, 1] i=5, [2, 3] + ... .... + i=14,[7, 0] i=14, [7, 2] + i=15,[7, 1] i=15, [7, 3] i=15, [7, 31] +*/ +//clang-format on +std::tuple get_coord_ref(int i, int wi_number) { + return std::make_tuple(i/2, ((i%2) + (wi_number*2))); +} + +//clang-format off +/* +Here's how the distribution of the A matrix looks like for this test case +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +<--------------------------------- SG1 ---------------------------------> +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x +<0> <1> <2> <3> <4> <5> <6> <7> ..... WORK ITEMS +Each work item has 16 elements <8 rows and 2 cols of the original matrix> +the data_slice in holds the matrix elements in the following order: +0 0 0 0 + / + / +1 1 1 1 + / + / +2 2 2 2 + / + / +3 3 3 3 +W0 --> 0 0 1 1 2 2 3 3 .... 7 7 +*/ +//clang-format on +template +void matrix_sum_rows(queue q, big_matrix &A, nd_range<2> &r) { + buffer bufA(A.get_data(), range<2>(M, K)); + // size of vector is known because SG size of set by the user in this case + int sum_rows[M] = {0}; + buffer sum_rows_v(sum_rows, M); // there are total of M rows + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + auto v = sum_rows_v.get_access(cgh); + auto os = sycl::stream(100000, 6144, cgh); + + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + + // TM = 8, TK = 32 + joint_matrix + sub_a; + + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (global_idx * TM * K) + TK, + K); + + // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_a + int32_t sum_local_rows[M] = {0}; // 8 local rows, M total + // sub_a has 8x32 elements, 16 elements per WI, 2 per WI per row + auto data = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + + size_t global_index; // Index into the result array that holds the sums. + + // Keep track of rows handled in this WI + int32_t handled_rows[M] = {-1}; + + // each WI calculates local sum of rows + for (int i = 0; i < data.length(); ++i) { + // get the index of the element in the submatrix + auto data_item = data[i]; + auto [row, col] = data_item.get_coord(); + global_index = row + global_idx*TM; + + sum_local_rows[global_index] += data[i]; + + handled_rows[global_index] = 1; + } + + for (int j=0; j < M; j++) { + if (handled_rows[j] == 1) { + global_index = j; + sum_local_rows[global_index] = reduce_over_group( + sg, sum_local_rows[global_index], + sycl::plus<>()); + // only Groups leader perform the global reduction + if (global_idy % SG_SZ == 0) { + atomic_fetch_add(v[global_index], + sum_local_rows[global_index]); + } + } + } + }); // parallel for + }).wait(); + sum_rows_ref(bufA.get_host_access(), sum_rows_v.get_host_access()); +} + + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_K = TK * 2; +int8_t A[MATRIX_M][MATRIX_K]; + +int main() { + big_matrix MA((int8_t *)&A); + + size_t NDRangeM = MATRIX_M / TM; + size_t NDRangeK = MATRIX_K / TK; + queue q; + nd_range<2> r({NDRangeM, NDRangeK * SG_SZ}, {1, 1 * SG_SZ}); + + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i; + } + } + + matrix_sum_rows(q, MA, r); + + std::cout << "Passed\n"; + + return 0; +} diff --git a/SYCL/Matrix/get_coord_bf16_matB.cpp b/SYCL/Matrix/get_coord_bf16_matB.cpp new file mode 100644 index 0000000000..d3f89079ce --- /dev/null +++ b/SYCL/Matrix/get_coord_bf16_matB.cpp @@ -0,0 +1,27 @@ +//==----------- get_coord_bf16_matB.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// XFAIL:* + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::bfloat16; + +#define SG_SZ 16 + +#include "get_coord_bf16_matB_impl.hpp" diff --git a/SYCL/Matrix/get_coord_bf16_matB_impl.hpp b/SYCL/Matrix/get_coord_bf16_matB_impl.hpp new file mode 100644 index 0000000000..8326b986b1 --- /dev/null +++ b/SYCL/Matrix/get_coord_bf16_matB_impl.hpp @@ -0,0 +1,218 @@ +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void sum_cols_ref(host_accessor B, + host_accessor sum_cols) { + int sum_cols_ref[N] = {0}; + for (size_t j = 0; j < N; j++) { + for (size_t i = 0; i < M; i++) { + sum_cols_ref[j] += B[i][j]; + } + auto diff = sum_cols[j] - sum_cols_ref[j]; + assert(std::fabs(static_cast(diff)) <= + std::numeric_limits::epsilon()); + } +} + +// clang-format off +/* + Here is a demonstration of how matrix B will be divided across + work items for this test case. + < --------------- 128 ----------------------------------> + x x x x x x x x x x x x x x x x .......... x x x x x x ^ + x x x x x x x x x x x x x x x x .......... x x x x x x 16 + x x x x x x x x x x x x x x x x .......... x x x x x x | + ..... | + x x x x x x x x x x x x x x x x .......... x x x x x x | + x x x x x x x x x x x x x x x x .......... x x x x x x v + + --------------- 64 ----------------> + x x x x x x .......... x x x x x x ^ + x x x x x x .......... x x x x x x 8 + x x x x x x .......... x x x x x x | <-- part of (VNNI-ed) + ..... | original matrix each SG + x x x x x x .......... x x x x x x | holds + x x x x x x .......... x x x x x x v + < WI0 > < WI15 > + <-------- 16 -------------> + x x x .......... x x x ^ + x x x .......... x x x | + x x x .......... x x x | <-- part of (non-VNNI-ed) original matrix + ..... | each SG holds + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x 32 + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x v + If we dividie the above matrix across 16 (SG_SZ) work items, + each WI will hold 32 elements. And these 32 elements will be + 8x4 chunks as shown in the VNNI-ed matrix figure. +*/ + +// The total distribution among the WIs in ALL the sub-groups is as follows: +// This is useful to figure out the the global index is to be calculated + +/* +W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements +wi [0,0] --> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | wi [0,16] --> i=0, [0, 64] + i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | i=1, [0, 65] + i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | i=2, [0, 66] + i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | i=3, [0, 67] + i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | .... + i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] | + i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] | + i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] | + ... ... .... | + i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | i=28, [7, 124] + i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | i=29, [7, 125] + i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | i=30, [7, 126] + i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | i=31, [7, 127] +---------------------------------------------------------------------------------------- --------------------------- +wi [1,0] --> i=0, [8, 0] + i=1, [8, 1] + i=2, [8, 2] + i=3, [8, 2] + ... + i=28, [15, 0] + i=29, [15, 1] + i=30, [15, 2] + i=31, [15, 3] +*/ + +// The following is the distribution among WIs in a SINGLE SG. +/* +W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements +wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | + i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | + i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | + i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | + i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | + i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] | + i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] | + i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] | + ... ... .... | + i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | + i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | + i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | + i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | +*/ +// clang-format on + +template +void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { + buffer bufB(B.get_data(), range<2>(M, N)); + // size of vector is known because SG size of set by the user in this case + int sum_cols[N] = {0}; + buffer sum_cols_v(sum_cols, N); // there are total of tK/4 * 2, 16 rows + q.submit([&](handler &cgh) { + auto accB = bufB.get_access(cgh); + + auto v = sum_cols_v.get_access(cgh); + auto os = sycl::stream(100000, 6144, cgh); + + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + + // TK = 32, TN = 16 + joint_matrix + sub_b; + + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (global_idx * (TK / 4) * N) + + sg_starty / SG_SZ * TN * 4, + N); + + int32_t sum_local_cols[N] = {0}; // 4 local cols, N total + // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row + auto wiData = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + + size_t + global_index; // Index into the result array that holds the sums. + + // Keep track of cols handled in this WI + int32_t handled_cols[N] = {-1}; + + // each WI calculates local sum of cols + for (int i = 0; i < wiData.length(); ++i) { + // get the index of the element in the submatrix + auto dataItem = wiData[i]; + auto [row, col] = dataItem.get_coord(); + + // Calculation of global index + int sg_idx = (int)global_idy / SG_SZ; + global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; + sum_local_cols[global_index] += wiData[i]; + handled_cols[global_index] = 1; + } + + for (int j = 0; j < N; j++) { + if (handled_cols[j] == 1) { + global_index = j; + sum_local_cols[global_index] = reduce_over_group( + sg, sum_local_cols[global_index], sycl::plus<>()); + // TODO: Do we need a reduce_over_grp? Adding it does not + // make any difference in result + atomic_fetch_add(v[global_index], sum_local_cols[global_index]); + } + } + }); // parallel for + }).wait(); + sum_cols_ref(bufB.get_host_access(), sum_cols_v.get_host_access()); +} + +// TK = 32, TN = 16 +static constexpr size_t MATRIX_K = TK / 4 * 2; // 16 +static constexpr size_t MATRIX_N = TN * 4 * 2; // 128 +int8_t B[MATRIX_K][MATRIX_N]; + +/* < --------------- 128 ----------------------------------> + x x x x x x x x x x x x x x x x .......... x x x x x x ^ + x x x x x x x x x x x x x x x x .......... x x x x x x 16 + x x x x x x x x x x x x x x x x .......... x x x x x x | + ..... | + x x x x x x x x x x x x x x x x .......... x x x x x x | + x x x x x x x x x x x x x x x x .......... x x x x x x v +*/ +int main() { + big_matrix MB((int8_t *)&B); + + size_t NDRangeK = MATRIX_K / (TK / 4); + size_t NDRangeN = (MATRIX_N / 4) / TN; + queue q; + nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + for (int i = 0; i < MATRIX_K; i++) { + for (int j = 0; j < MATRIX_N; j++) { + B[i][j] = i; + } + } + + matrix_sum_cols(q, MB, r); + + std::cout << "Passed\n"; + + return 0; +}