Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ template <typename T, std::size_t R, std::size_t C,
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *, size_t i);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
Expand Down
44 changes: 44 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/detail/defines_elementary.hpp>
#include <sycl/feature_test.hpp>
#include <tuple>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -250,6 +251,21 @@ class wi_element {
wi_element(joint_matrix<T, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

// Various Operations
operator T() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -332,6 +348,20 @@ class wi_element<uint16_t, NumRows, NumCols, Use, Layout, Group> {
wi_element(joint_matrix<uint16_t, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator uint16_t() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -482,6 +512,20 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator sycl::ext::oneapi::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down
112 changes: 112 additions & 0 deletions sycl/test/matrix/matrix-bfloat16-test-coord-basic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=2 %s -o %t.out
// RUN: %t.out
// XFAIL: *

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 16

#define TN SG_SZ
#define TK 32

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
T *mat;

public:
T *get_data() { return mat; }
void set_data(T *data) { mat = data; }
big_matrix(T *data) : mat(data) {}
};

template <typename T, size_t M, size_t N>
void sum_rows_ref(
accessor<T, 2, access::mode::read, access::target::host_buffer> B,
accessor<int, 1, access::mode::read, access::target::host_buffer>
sum_rows) {
int sum_rows_ref[M] = {0};
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < N; j++) {
sum_rows_ref[i] += B[i][j];
}
auto diff = sum_rows[i] - sum_rows_ref[i];
assert(std::fabs(static_cast<int>(diff)) <=
std::numeric_limits<int>::epsilon());
}
}

template <typename T, size_t M, size_t N>
void matrix_sum_rows(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) {
buffer<int8_t, 2> bufB(B.get_data(), range<2>(M, N));
// size of vector is known because SG size of set by the user in this case
int sum_rows[M] = {0};
buffer<int> sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows
q.submit([&](handler &cgh) {
auto accB = bufB.get_access<access::mode::read_write>(cgh);

auto v = sum_rows_v.get_access<access::mode::atomic>(cgh);

cgh.parallel_for<class add_matrix>(
r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();

joint_matrix<T, TK, TN, use::b> sub_b(sg);

joint_matrix_load(sg, sub_b,
accB.get_pointer() + (global_idx * (TK / 4) * N) +
sg_starty / SG_SZ * TN * 4,
N, layout::packed_b);

int32_t sum_local_rows[M] = {0};
auto tBData = sub_b.get_wi_data();

// each WI calculates local sum of rows
for (int i = 0; i < tBData.length(); ++i) {
// row and col holds global co_ordinates of the matrix
auto [row, col] = tBData[i].get_coord();
sum_local_rows[row] += tBData[i];

sum_local_rows[row] =
reduce_over_group(sg, sum_local_rows[row], sycl::plus<>());
// only Groups leader perform the global reduction
if (global_idy % SG_SZ == 0) {
atomic_fetch_add(v[row], sum_local_rows[row]);
}
}
}); // parallel for
}).wait();
sum_rows_ref<T, M, N>(bufB.get_access<access::mode::read>(),
sum_rows_v.get_access<access::mode::read>());
}

static constexpr size_t MATRIX_K = TK / 4 * 2;
static constexpr size_t MATRIX_N = TN * 4 * 2;
int8_t B[MATRIX_K][MATRIX_N];

int main() {
big_matrix<int8_t, MATRIX_K, MATRIX_N> MB((int8_t *)&B);

size_t NDRangeK = MATRIX_K / (TK / 4);
size_t NDRangeN = (MATRIX_N / 4) / TN;
queue q;
nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});

for (int i = 0; i < MATRIX_K; i++) {
for (int j = 0; j < MATRIX_N; j++) {
B[i][j] = i;
}
}

matrix_sum_rows<int8_t, MATRIX_K, MATRIX_N>(q, MB, r);

return 0;
}
Loading