8 changes: 8 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
@@ -123,6 +123,14 @@ template <typename T, std::size_t R, std::size_t C,
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
JOINT_MATRIX_INTEL(T, R, C, L, S, U) *);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL std::tuple<T, T>
__spirv_JointMatrixWorkItemElemCoord(JOINT_MATRIX_INTEL(T, R, C, L, S, U) *,
Contributor:

There is no such thing as std::tuple in SPIR-V. The instruction should return an int2, and if we want a tuple for the get_coord API, we should read the elements from that vector to build the tuple.

Contributor Author:

Thanks for the comment. Can you tell me a bit more about this int2 type? Is there any documentation or code I can take a look at?

Contributor:

It's a two-element vector. int2 is the OpenCL spelling, but the appropriate alias should be known to DPC++; see: https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_aliases

Contributor Author:

Done

Contributor (@yubingex007-a11y, Oct 26, 2022):

Why do we use __ocl_vec_t<int32_t, 2> instead of sycl::vec<int32_t, 2> here? @MrSidims

Contributor:

Please hold off on merging until the name for the instruction is picked. In the draft SPIR-V spec it is JointMatrixGetElementCoordINTEL.

size_t i);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
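To make the int2 suggestion in the thread above concrete, here is a minimal sketch of what the builtin declaration could look like if it returned a two-element vector instead of std::tuple, using the __ocl_vec_t alias and the draft instruction name JointMatrixGetElementCoordINTEL mentioned in the thread. This is an illustration only, not the code in this PR:

// Sketch: the builtin returns a two-element vector, since std::tuple does not
// exist in SPIR-V; the SYCL-side get_coord() can then unpack it into a tuple.
template <typename T, std::size_t R, std::size_t C,
          __spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __ocl_vec_t<int32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(JOINT_MATRIX_INTEL(T, R, C, L, S, U) *,
                                        size_t i);

The caller would then read the two lanes of the returned vector (elements 0 and 1) and build the std::tuple exposed by get_coord from them.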
13 changes: 13 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp
@@ -12,6 +12,7 @@
#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/detail/defines_elementary.hpp>
#include <sycl/feature_test.hpp>
#include <tuple>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
@@ -256,6 +257,18 @@ class wi_element {
wi_element(joint_matrix<T, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

// Functions
Contributor:

remove this comment

Contributor Author:

Done

std::tuple<size_t, size_t> get_coord() {
Contributor:

Do you need to add this function to the wi_element specialization for the bfloat16 type as well?

Contributor Author:

Done

Contributor:

nit: size_t -> uint32_t is probably better; the same nit applies to the code below.

#ifdef __SYCL_DEVICE_ONLY__
return __spirv_JointMatrixWorkItemElemCoord(M.spvm, idx);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

// Various Operations
operator T() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
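Following the review notes above (add get_coord() to the bfloat16 specialization of wi_element as well, and prefer uint32_t over size_t for the coordinates), here is a minimal sketch of what that member could look like, assuming the vector-returning builtin sketched earlier. It is an illustration only, not the code in this PR:

std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
  // Unpack the two-element coordinate vector returned by the builtin
  // (hypothetical name taken from the draft spec) into the user-facing tuple.
  __ocl_vec_t<int32_t, 2> coord =
      __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
  return std::make_tuple(static_cast<uint32_t>(coord[0]),
                         static_cast<uint32_t>(coord[1]));
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}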
205 changes: 205 additions & 0 deletions sycl/test/matrix/matrix-bfloat16-test-coord.cpp
@@ -0,0 +1,205 @@
// RUN: %clangxx -fsycl -fsycl-targets=spir64_gen -DSYCL_EXT_ONEAPI_MATRIX_VERSION=2 -S -emit-llvm %s -o %t.out
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;

static constexpr auto TILE_SZ = 16;
static constexpr auto TM = TILE_SZ - 1;
static constexpr auto TN = TILE_SZ - 1;
static constexpr auto TK = 2 * TILE_SZ - 2;

static constexpr auto SG_SZ = 16;

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
T *mat;

public:
T *get_data() { return mat; }
void set_data(T *data) { mat = data; }
big_matrix(T *data) : mat(data) {}
};

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
unsigned short Aref[MATRIX_M][MATRIX_K];
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];
int32_t *res_local_row;

template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
size_t M = NUM_ROWS_C;
size_t N = NUM_COLS_C;
size_t K = NUM_COLS_A;
// B => K/2 x N*2, A => M x K, C => M x N
// The stride should be the matrix's number of columns, e.g., B's stride = N*2.
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
size_t NDRangeM = M / TM;
size_t NDRangeN = N / TN;
sycl::buffer<bfloat16, 2> bufA(A.get_data(), sycl::range<2>(M, K));
sycl::buffer<bfloat16, 2> bufB(B.get_data(), sycl::range<2>(K, N));
sycl::buffer<float, 2> bufC((float *)C.get_data(), sycl::range<2>(M, N));

sycl::buffer<int32_t, 1> res_local_row_buf(res_local_row,
sycl::range<1>(MATRIX_M));

sycl::queue q;
q.submit([&](sycl::handler &cgh) {
auto accC = bufC.get_access<sycl::access::mode::read_write>(cgh);
auto accA = bufA.get_access<sycl::access::mode::read_write>(cgh);
auto accB = bufB.get_access<sycl::access::mode::read_write>(cgh);
auto res_local_row_acc =
res_local_row_buf.get_access<sycl::access::mode::read_write>(cgh);

cgh.parallel_for<class imatrix>(
sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
[accA, accB, accC, M, N, K,
res_local_row_acc](sycl::nd_item<2> spmd_item)

{
// The submatrix API has to be accessed by all the work-items in a
// subgroup; these functions are called once per subgroup, so there is
// no code divergence between the work-items.
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sycl::ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<bfloat16, TM, TK, use::a> sub_a(sg);
// For B, since the current implementation does not support a non-packed
// layout, users need to specify the updated VNNI sizes along with the
// packed_b layout. By default, the layout is row_major and the size is
// (TK, TN).
joint_matrix<bfloat16, TK, TN, use::b> sub_b(sg);
joint_matrix<float, TM, TN, use::accumulator> sub_c(sg);

joint_matrix_load(sg, sub_c,
accC.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, layout::row_major);
for (int k = 0; k < K / TK; k += 1) {
joint_matrix_load(
sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
K, layout::row_major);
// Assuming B data is already in VNNI format.
joint_matrix_load(sg, sub_b,
accB.get_pointer() + (k * TK / 2) * (N * 2) +
sg_starty / SG_SZ * TN * 2,
N * 2, layout::packed_b);
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
}
joint_matrix_store(sg, sub_c,
accC.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, layout::row_major);
// Element wise operation
auto tCData = sub_c.get_wi_data();

for (int i = 0; i < tCData.length(); ++i) {
size_t row, col;
std::tie(row, col) = tCData[i].get_coord();
Contributor:

You can also use structured bindings, e.g. auto [row, col] = ..., to avoid calling std::tie.

Contributor Author:

Done
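A minimal sketch of the structured-bindings form suggested above (illustrative only; the diff below still shows the std::tie version):

// Unpacks the returned tuple directly, avoiding the separate
// row/col declarations and the std::tie call.
auto [row, col] = tCData[i].get_coord();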

res_local_row_acc[row] += tCData[i];
Contributor:

You need to return res_local_row_acc and use it in the verify function.

Contributor Author:

Done

}
}); // parallel for
}).wait();
}

float make_fp32(short x) {
unsigned int y = x;
y = y << 16;
float *res = reinterpret_cast<float *>(&y);
return *res;
}

unsigned short make_bf16(float x) {
int *res = reinterpret_cast<int *>(&x);
*res = *res >> 16;
return (unsigned short)*res;
}

void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
int K) {
// tiling
for (int m = 0; m < M; m++)
for (int n = 0; n < N; n++) {
for (int k = 0; k < K; k++) {
short *va = (short *)(A_mem + m * K + k);
short *vb = (short *)(B_mem + k * N + n);
float acc = *((float *)(C_mem + m * N + n));
// FIXME: Should we do reduce-add in another version?
for (int i = 0; i < 2; i++) {
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
}
*((float *)(C_mem + m * N + n)) = acc;
}
}
}

int main() {
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_K; j++) {
// We create bfloat16 from unsigned short since a float-to-bfloat16
// conversion is not allowed.
A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j)));
Aref[i][j] = make_bf16(1.0f * (i + j));
}
}
for (int i = 0; i < MATRIX_K / 2; i++) {
for (int j = 0; j < MATRIX_N * 2; j++) {
B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j)));
Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
}
}
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_N; j++) {
C[i][j] = 1.0;
D[i][j] = 1.0;
}
}

big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);

res_local_row = (int32_t *)calloc(MATRIX_M, sizeof(int32_t));

matrix_multiply(MC, MA, MB);
matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
MATRIX_N, MATRIX_K / 2);

bool res = true;
Contributor:

This is what I call the verify function. matrix_multiply_ref should also calculate the sum of each row and return that, so it can be compared against res_local_row.

Contributor Author:

Done
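To make the verification request above concrete, here is a minimal sketch of a host-side row-sum reference that the test could compare against res_local_row. The helper name sum_rows_ref is hypothetical and not part of this PR, and exact agreement may require mirroring the int/float conversions the kernel performs while accumulating:

// Hypothetical helper: accumulates the sum of each row of the reference
// result so it can be checked against res_local_row filled on the device.
void sum_rows_ref(float *C_mem, int32_t *row_sums, int M, int N) {
  for (int m = 0; m < M; m++) {
    int32_t sum = 0;
    for (int n = 0; n < N; n++)
      sum += static_cast<int32_t>(C_mem[m * N + n]);
    row_sums[m] += sum;
  }
}

main() could then call sum_rows_ref((float *)D, ref_row_sums, MATRIX_M, MATRIX_N) on a zero-initialized ref_row_sums array and compare each entry with res_local_row[i].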

for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_N; j++) {
if (C[i][j] != D[i][j])
res = false;
}
}
if (res)
std::cout << "passed\n";
else
std::cout << "failed\n";
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_N; j++)
std::cout << C[i][j] << ", ";
std::cout << "\n";
}
std::cout << std::endl;
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_N; j++)
std::cout << D[i][j] << ", ";
std::cout << "\n";
}
}