Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@
#endif

#ifdef __SYCL_DEVICE_ONLY__

// JOINT_MATRIX_INTEL(T, R, C, L, S, U) spells the joint-matrix type named by
// the declarations in this header. When __SYCL_EXT_ONEAPI_MATRIX_USE__ is
// defined, all six arguments (including the use parameter U) are forwarded to
// __spv::__spirv_JointMatrixINTEL; otherwise U is dropped and only the
// five-argument form of the template is named.
#ifdef __SYCL_EXT_ONEAPI_MATRIX_USE__
#define JOINT_MATRIX_INTEL(T, R, C, L, S, U) \
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U>
#else
#define JOINT_MATRIX_INTEL(T, R, C, L, S, U) \
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S>
#endif // __SYCL_EXT_ONEAPI_MATRIX_USE__

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T, R, C, L, S, U) *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);
Expand All @@ -36,7 +45,7 @@ template <typename T, std::size_t R, std::size_t C,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *Object,
T *Ptr, JOINT_MATRIX_INTEL(T, R, C, L, S, U) *Object,
std::size_t Stride, __spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

Expand All @@ -48,11 +57,11 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S, UC> *
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T2, M, N, LC, S, UC) *
__spirv_JointMatrixMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T1, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S, UC> *C,
JOINT_MATRIX_INTEL(T1, M, K, LA, S, UA) *A,
JOINT_MATRIX_INTEL(T1, K, N, LB, S, UB) *B,
JOINT_MATRIX_INTEL(T2, M, N, LC, S, UC) *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
Expand All @@ -63,11 +72,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
// The result uses element type T3, the same type as the accumulator operand C,
// consistent with the USMad and SUMad declarations in this header.
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *
__spirv_JointMatrixUUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
JOINT_MATRIX_INTEL(T1, M, K, LA, S, UA) *A,
JOINT_MATRIX_INTEL(T2, K, N, LB, S, UB) *B,
JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
Expand All @@ -78,11 +87,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *
__spirv_JointMatrixUSMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
JOINT_MATRIX_INTEL(T1, M, K, LA, S, UA) *A,
JOINT_MATRIX_INTEL(T2, K, N, LB, S, UB) *B,
JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
Expand All @@ -93,38 +102,39 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *
__spirv_JointMatrixSUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
JOINT_MATRIX_INTEL(T1, M, K, LA, S, UA) *A,
JOINT_MATRIX_INTEL(T2, K, N, LB, S, UB) *B,
JOINT_MATRIX_INTEL(T3, M, N, LC, S, UC) *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T, R, C, L, S, U) *
__spirv_CompositeConstruct(const T v);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);
JOINT_MATRIX_INTEL(T, R, C, L, S, U) *);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);
JOINT_MATRIX_INTEL(T, R, C, L, S, U) *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
extern SYCL_EXTERNAL JOINT_MATRIX_INTEL(T, R, C, L, S, U) *
__spirv_VectorInsertDynamic(JOINT_MATRIX_INTEL(T, R, C, L, S, U) *,
T val, size_t i);
#undef JOINT_MATRIX_INTEL

#ifndef __SPIRV_BUILTIN_DECLARATIONS__
#error \
Expand Down
9 changes: 9 additions & 0 deletions sycl/include/CL/__spirv/spirv_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ enum class MatrixUse : uint32_t {
// information to SPIRV translator.
// The long term solution would be to introduce a matrix type in Clang and use
// it instead of this member.
#ifdef __SYCL_EXT_ONEAPI_MATRIX_USE__
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
Scope::Flag S = Scope::Flag::Subgroup,
MatrixUse U = MatrixUse::Unnecessary>
Expand All @@ -145,6 +146,14 @@ struct __spirv_JointMatrixINTEL {
[R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1]
[static_cast<size_t>(U) + 1];
};
#else
// Variant selected when __SYCL_EXT_ONEAPI_MATRIX_USE__ is not defined: it takes
// no MatrixUse parameter, so the array type of Value encodes only R, C, L + 1
// and S + 1.
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
Scope::Flag S = Scope::Flag::Subgroup>
struct __spirv_JointMatrixINTEL {
T(*Value)
[R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1];
};
#endif // __SYCL_EXT_ONEAPI_MATRIX_USE__

} // namespace __spv

Expand Down
14 changes: 6 additions & 8 deletions sycl/include/sycl/ext/oneapi/matrix/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@

#include <sycl/feature_test.hpp>

// the default is matrix-jit-use but existing tests in llvm-test-suite won't
// fail because we have the "unnecessary" use value
#if (SYCL_EXT_ONEAPI_MATRIX == 1)
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
#include <sycl/ext/oneapi/matrix/static-query.hpp>
#if (SYCL_EXT_ONEAPI_MATRIX == 3)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JackAKirk this doesn't look like a correct use of SYCL_EXT_ONEAPI_MATRIX (users shouldn't pass such a macro during compilation; see #6662 (comment)). What would be an appropriate replacement for it?

Copy link
Contributor

@JackAKirk JackAKirk Oct 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no intention of using any macro in the future to decide which APIs to use. For the initial experimental Tensor Cores implementation, we followed the existing pattern used by the Intel AOT and JIT implementations and set a new value for the feature-test macro. The SYCL spec itself doesn't appear to forbid this, although it sounds like the extension document itself contradicts the implementation in this regard!

We plan to add a separate macro analogous to SYCL_EXT_ONEAPI_MATRIX_LEGACY_API; we could name it SYCL_EXT_ONEAPI_MATRIX_LEGACY_CUDA_API, for example.
Since the interfaces for the combined Intel/CUDA implementation have been decided, I could also open a PR that supports both the Intel and CUDA backends within the same implementation. #5920 is blocked on a SPIR-V builtin at the moment, so it sounds like we shouldn't wait on this PR before starting the combined implementation. We could add the macro for the legacy version at the same time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just talked with @dkhaldi, and it sounds like she prefers to have a different name for the macro that chooses the API.

#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
#endif
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
#elif (SYCL_EXT_ONEAPI_MATRIX == 2 || __SYCL_EXT_ONEAPI_MATRIX_USE__)
#include <sycl/ext/oneapi/matrix/matrix-jit-use.hpp>
#include <sycl/ext/oneapi/matrix/static-query-use.hpp>
#endif
#if (SYCL_EXT_ONEAPI_MATRIX == 3)
#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
#elif (SYCL_EXT_ONEAPI_MATRIX == 1)
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
#include <sycl/ext/oneapi/matrix/static-query.hpp>
#endif
9 changes: 1 addition & 8 deletions sycl/include/sycl/feature_test.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,7 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
#define SYCL_EXT_INTEL_DEVICE_INFO 3
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
// As for SYCL_EXT_ONEAPI_MATRIX:
// 1- provides AOT initial implementation for AMX for the experimental matrix
// extension
// 2- provides JIT implementation (target agnostic) for the
// experimental matrix extension
#ifndef SYCL_EXT_ONEAPI_MATRIX
#define SYCL_EXT_ONEAPI_MATRIX 2
#endif
#define SYCL_EXT_ONEAPI_MATRIX 1
#define SYCL_EXT_ONEAPI_ASSERT 1
#define SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS 1
#define SYCL_EXT_ONEAPI_DISCARD_QUEUE_EVENTS 1
Expand Down
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-bf16-test-SG-16.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
// RUN: %clangxx -fsycl -O2 %s -o %t.out
#include <iostream>
#include <sycl/sycl.hpp>

Expand Down
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-bf16-test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
// RUN: %clangxx -fsycl -O2 %s -o %t.out
#include <iostream>
#include <sycl/sycl.hpp>

Expand Down
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-bfloat16-test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
// RUN: %clangxx -fsycl -O2 %s -o %t.out
#include <iostream>
#include <sycl/sycl.hpp>

Expand Down
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-elemwise-ops.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
// RUN: %clangxx -fsycl -O2 %s -o %t.out

#include <iostream>
#include <sycl/sycl.hpp>
Expand Down
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-int8-test-SG-16.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1
// RUN: %clangxx -fsycl -O2 %s -o %t.out
#include <iostream>
#include <sycl/sycl.hpp>

Expand Down
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-int8-test-use.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s
// RUN: %clangxx -fsycl -fsycl-device-only -D__SYCL_EXT_ONEAPI_MATRIX_USE__ -O2 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_4_3_0 = type { [12 x [48 x [5 x [4 x [1 x i8]]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_4_3_2 = type { [12 x [12 x [5 x [4 x [3 x i32]]]]] addrspace(4)* }
Expand Down
8 changes: 4 additions & 4 deletions sycl/test/matrix/matrix-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_3 = type { [12 x [48 x [1 x [4 x [4 x i8]]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3_3 = type { [12 x [12 x [1 x [4 x [4 x i32]]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3_3 = type { [48 x [12 x [4 x [4 x [4 x i8]]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* }
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* }

#include <iostream>
#include <sycl/sycl.hpp>
Expand Down
2 changes: 1 addition & 1 deletion sycl/test/matrix/query.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -o query %s
// RUN: %clangxx -fsycl -o query %s
#include <iostream>
#include <sycl/sycl.hpp>

Expand Down