-
Notifications
You must be signed in to change notification settings - Fork 802
[SYCL][CUDA][MATRIX] joint_matrix_bmad implementation #5363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
d34df92
04a4a34
48ab1f3
fc7ebbd
af13da9
0f423c6
f35956c
6c11a2c
95df95d
82f1996
95385cf
a693cd0
cabcee6
c05d2a1
c36bea2
48978cc
445a41f
7457400
0283942
42e2b17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,6 +81,34 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2) | |
| __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2) | ||
| __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8) | ||
|
|
||
| // single-bit | ||
| template <matrix_layout Layout> | ||
| struct joint_matrix< | ||
| uint32_t, matrix_use::a, 8, 4, Layout, sycl::sub_group, | ||
| typename std::enable_if_t<Layout == matrix_layout::row_major || | ||
| Layout == matrix_layout::col_major>> { | ||
| joint_matrix() { | ||
| static_assert((Layout == matrix_layout::row_major), | ||
| "For the matrix_use::a case, matrix_layout::row_major must " | ||
| "be used for Bitwise MAD"); | ||
| }; | ||
| int32_t data[1]; | ||
| }; | ||
|
|
||
| template <matrix_layout Layout> | ||
| struct joint_matrix< | ||
| uint32_t, matrix_use::b, 4, 8, Layout, sycl::sub_group, | ||
| typename std::enable_if_t<Layout == matrix_layout::row_major || | ||
| Layout == matrix_layout::col_major>> { | ||
| joint_matrix() { | ||
| static_assert((Layout == matrix_layout::col_major), | ||
| "For the matrix_use::b case, matrix_layout::col_major must " | ||
| "be used for Bitwise MAD"); | ||
| }; | ||
| int32_t data[1]; | ||
| }; | ||
| __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 8, int32_t, 2) | ||
|
|
||
| #undef __SYCL_JOINT_MATRIX_OVERLOAD | ||
| } // namespace experimental::matrix | ||
|
|
||
|
|
@@ -235,6 +263,28 @@ struct joint_matrix_load_impl< | |
| get_layout_id<Layout>()); | ||
| } | ||
|
|
||
| } else if constexpr (std::is_same<T, double>::value) { | ||
| if constexpr (Use == | ||
| sycl::ext::oneapi::experimental::matrix::matrix_use::a) { | ||
| __dmma_m8n8k4_ld_a(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: | ||
| matrix_use::b) { | ||
| __dmma_m8n8k4_ld_b(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: | ||
| matrix_use::accumulator) { | ||
| __dmma_m8n8k4_ld_c(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } | ||
| } else if constexpr (NumRows == 8 && NumCols == 4) { | ||
| int32_t *tileptr = reinterpret_cast<int32_t *>(src.get()); | ||
| __bmma_m8n8k128_ld_a_b1(res.data, tileptr, stride * 32, | ||
| get_layout_id<Layout>()); | ||
| } else if constexpr (NumRows == 4 && NumCols == 8) { | ||
| int32_t *tileptr = reinterpret_cast<int32_t *>(src.get()); | ||
| __bmma_m8n8k128_ld_b_b1(res.data, tileptr, stride * 32, | ||
| get_layout_id<Layout>()); | ||
| } else if constexpr (std::is_same<T, int32_t>::value) { | ||
| if constexpr (NumRows == 16 && NumCols == 16) { | ||
| __imma_m16n16k16_ld_c(res.data, src.get(), stride, | ||
|
|
@@ -245,6 +295,9 @@ struct joint_matrix_load_impl< | |
| } else if constexpr (NumRows == 32 && NumCols == 8) { | ||
| __imma_m32n8k16_ld_c(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } else if constexpr (NumRows == 8 && NumCols == 8) { | ||
| __bmma_m8n8k128_ld_c(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } | ||
| } else if constexpr (std::is_same<T, float>::value) { | ||
| if constexpr (NumRows == 16 && NumCols == 16) { | ||
|
|
@@ -257,20 +310,6 @@ struct joint_matrix_load_impl< | |
| __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } | ||
| } else if constexpr (std::is_same<T, double>::value) { | ||
| if constexpr (Use == | ||
| sycl::ext::oneapi::experimental::matrix::matrix_use::a) { | ||
| __dmma_m8n8k4_ld_a(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: | ||
| matrix_use::b) { | ||
| __dmma_m8n8k4_ld_b(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: | ||
| matrix_use::accumulator) { | ||
| __dmma_m8n8k4_ld_c(res.data, src.get(), stride, | ||
| get_layout_id<Layout>()); | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
|
@@ -339,6 +378,9 @@ struct joint_matrix_store_impl< | |
| } else if constexpr (std::is_same<T, double>::value) { | ||
| __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride, | ||
| get_layout_id<Layout>()); | ||
| } else if constexpr (std::is_same<T, int32_t>::value) { | ||
| __bmma_m8n8k128_st_c_i32(dst.get(), src.data, stride, | ||
| get_layout_id<Layout>()); | ||
| } | ||
| } | ||
| }; | ||
|
|
@@ -366,6 +408,31 @@ struct joint_matrix_mad_impl { | |
| C); | ||
| }; | ||
|
|
||
// Primary template for the bitwise-MAD ("binary dot product") implementation
// helper.  Declared only; the definition lives in the layout-constrained
// partial specialization below (Cond selects it via enable_if).  M x N is
// the int32_t accumulator shape; K counts the packed 32-bit words of the
// single-bit A/B operands (i.e. K * 32 bit-elements -- see the bmad test).
// A is fixed to row_major and B to col_major, matching the only layouts the
// underlying intrinsics accept.
template <std::size_t M, std::size_t K, std::size_t N,
          sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
          class BinaryOperation, typename Cond = void>
struct joint_matrix_bmad_impl {
  sycl::ext::oneapi::experimental::matrix::joint_matrix<
      int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
      M, N, LayoutC, sycl::sub_group>
  bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M,
           K, sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
           sycl::sub_group>
           A,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K,
           N, sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
           sycl::sub_group>
           B,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           int32_t,
           sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
           N, LayoutC, sycl::sub_group>
           C,
       BinaryOperation Op);
};
|
|
||
| template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA, | ||
| sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB> | ||
| constexpr int get_layout_pair_id(); | ||
|
|
@@ -495,14 +562,59 @@ struct joint_matrix_mad_impl< | |
| get_layout_pair_id<LayoutA, LayoutB>(), 0); | ||
| } | ||
| } | ||
| } else if constexpr (std::is_same<T1, double>::value) { | ||
| } else if constexpr (M == 8 && N == 8 && K == 4) { | ||
|
||
| __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data, | ||
| get_layout_pair_id<LayoutA, LayoutB>(), 0); | ||
| } | ||
| return D; | ||
| } | ||
| }; | ||
|
|
||
| template <std::size_t M, std::size_t K, std::size_t N, | ||
| sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC, | ||
| class BinaryOperation> | ||
| struct joint_matrix_bmad_impl< | ||
| M, K, N, LayoutC, BinaryOperation, | ||
| typename std::enable_if_t<( | ||
| LayoutC == | ||
| sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major || | ||
| LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout:: | ||
| col_major)>> { | ||
| sycl::ext::oneapi::experimental::matrix::joint_matrix< | ||
| int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, | ||
| M, N, LayoutC, sycl::sub_group> | ||
| bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix< | ||
| uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, | ||
| K, sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major, | ||
| sycl::sub_group> | ||
| A, | ||
| sycl::ext::oneapi::experimental::matrix::joint_matrix< | ||
| uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, | ||
| N, sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major, | ||
| sycl::sub_group> | ||
| B, | ||
| sycl::ext::oneapi::experimental::matrix::joint_matrix< | ||
| int32_t, | ||
| sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, | ||
| N, LayoutC, sycl::sub_group> | ||
| C, | ||
| BinaryOperation Op) { | ||
| sycl::ext::oneapi::experimental::matrix::joint_matrix< | ||
| int32_t, | ||
| sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, | ||
| LayoutC, sycl::sub_group> | ||
| D; | ||
| if constexpr (std::is_same<BinaryOperation, | ||
| sycl::bit_and<uint32_t>>::value) { | ||
| __bmma_m8n8k128_mma_and_popc_b1(D.data, A.data, B.data, C.data, 1); | ||
| } else if constexpr (std::is_same<BinaryOperation, | ||
| sycl::bit_xor<uint32_t>>::value) { | ||
| __bmma_m8n8k128_mma_xor_popc_b1(D.data, A.data, B.data, C.data, 1); | ||
| } | ||
| return D; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace detail | ||
|
|
||
| namespace experimental::matrix { | ||
|
|
@@ -573,6 +685,33 @@ joint_matrix_mad( | |
| #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
| } | ||
|
|
||
// Sub-group cooperative bitwise multiply-and-add entry point.  Combines the
// packed single-bit A and B fragments with Op (sycl::bit_and<uint32_t> or
// sycl::bit_xor<uint32_t> -- other functors are rejected by the
// implementation) and accumulates into C, returning the new accumulator.
// A must be row_major and B col_major; the joint_matrix specializations
// above enforce this with static_asserts.  Only implemented for CUDA
// (NVPTX) devices; any other target throws at run time.
template <typename Group, std::size_t M, std::size_t K, std::size_t N,
          matrix_layout LayoutC, class BinaryOperation>
joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_bmad(
    Group sg,
    joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major, Group>
        A,
    joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major, Group>
        B,
    joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group> C,
    BinaryOperation Op) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  return sycl::ext::oneapi::detail::joint_matrix_bmad_impl<M, K, N, LayoutC,
                                                           BinaryOperation>{}
      .bmad(A, B, C, Op);
#else
  // Host / non-NVPTX compilation path: silence unused-parameter warnings
  // and fail at run time, matching the other joint_matrix entry points.
  (void)sg;
  (void)A;
  (void)B;
  (void)C;
  (void)Op;
  throw runtime_error("joint_matrix_bmad is "
                      "only supported by CUDA devices",
                      PI_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}
|
|
||
| } // namespace experimental::matrix | ||
| } // namespace oneapi | ||
| } // namespace ext | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| // REQUIRES: cuda | ||
|
|
||
| // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s | ||
|
|
||
| #include <CL/sycl.hpp> | ||
|
|
||
| using namespace sycl; | ||
| using namespace sycl::ext::oneapi::experimental::matrix; | ||
|
|
||
// M, N, (K * 32) define the sizes of dimensions of the three matrix types (a,
// b, accumulator) used per subgroup operation.
constexpr int M = 8; // number of rows of accumulator,
                     // number of rows of a.
constexpr int N = 8; // number of cols of accumulator,
                     // number of cols of b.
constexpr int K = 4; // number of cols of a/number of rows of b divided by 32

// Each bit of each uint32_t A/B array element is an element of a single-bit
// matrix. joint_matrix_bmad performs Binary Dot Products on these matrices (see
// M. Rastegari et al. Computer Vision – ECCV 2016, 525-542 and A. Li et al.
// IEEE Transactions on Parallel and Distributed Systems, 32(7):1878-1891,
// 2021))
uint32_t A[M * K];
uint32_t B[K * N];
int32_t C[M * N];
int32_t D[M * N];
|
|
||
int main() {

  // Wrap the host arrays in 1-D buffers so the kernel can access them.
  buffer<uint32_t, 1> bufA(A, range<1>(M * K));
  buffer<uint32_t, 1> bufB(B, range<1>(K * N));
  buffer<int32_t, 1> bufC(C, range<1>(M * N));
  buffer<int32_t, 1> bufD(D, range<1>(M * N));

  queue q;

  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    // Launch a single work-group of 32 work-items (one full sub-group):
    // the joint_matrix operations are cooperative across the sub-group.
    cgh.parallel_for<class row_col>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<int32_t, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major>
              sub_a;

          joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major>
              sub_b;

          // The CHECK lines below are FileCheck directives pinning the NVVM
          // WMMA intrinsics each call must lower to; do not edit them.
          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
          //CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.a.row.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.b.col.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1(i32 %3, i32 %4, i32 %1, i32 %2) #{{.*}}
          sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
                                    sycl::bit_xor<uint32_t>());
          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1(i32 %3, i32 %4, i32 %6, i32 %7) #{{.*}}
          sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
                                    sycl::bit_and<uint32_t>());
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k128.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %9, i32 %10, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

  return 0;
};
Uh oh!
There was an error while loading. Please reload this page.