From e0e403f5fc8ccf9ce90fbf770710188d1dcf9285 Mon Sep 17 00:00:00 2001 From: Sergey Dmitriev Date: Tue, 14 Dec 2021 00:51:49 -0800 Subject: [PATCH] [ESIMD] Add support for an arbitrary number of elements to simd::copy_from/to This patch adds support for simd objects with any number of elements to simd::copy_from/to methods. Signed-off-by: Sergey Dmitriev --- .../esimd/detail/simd_obj_impl.hpp | 302 ++++++++++++------ .../intel/experimental/esimd/detail/util.hpp | 15 + sycl/test/esimd/simd_copy_align_flags.cpp | 20 ++ 3 files changed, 243 insertions(+), 94 deletions(-) diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp index 914ab154a9005..7c908c57935a3 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp @@ -577,7 +577,7 @@ template class simd_obj_impl { /// vector_aligned_tag, \p addr must be aligned by simd_obj_impl's vector_type /// alignment. If Flags is overaligned_tag, \p addr must be aligned by N. /// Program not meeting alignment requirements results in undefined behavior. - template >> ESIMD_INLINE void copy_from(const Ty *addr, Flags = {}) SYCL_ESIMD_FUNCTION; @@ -593,6 +593,7 @@ template class simd_obj_impl { /// alignment. If Flags is overaligned_tag, offset must be aligned by N. /// Program not meeting alignment requirements results in undefined behavior. template >> ESIMD_INLINE EnableIfAccessor @@ -606,7 +607,7 @@ template class simd_obj_impl { /// vector_aligned_tag, \p addr must be aligned by simd_obj_impl's vector_type /// alignment. If Flags is overaligned_tag, \p addr must be aligned by N. /// Program not meeting alignment requirements results in undefined behavior. - template >> ESIMD_INLINE void copy_to(Ty *addr, Flags = {}) const SYCL_ESIMD_FUNCTION; @@ -621,6 +622,7 @@ template class simd_obj_impl { /// alignment. If Flags is overaligned_tag, offset must be aligned by N. /// Program not meeting alignment requirements results in undefined behavior. template >> ESIMD_INLINE EnableIfAccessor @@ -733,42 +735,67 @@ template class simd_obj_impl { // ----------- Outlined implementations of simd_obj_impl class APIs. template -template +template void simd_obj_impl::copy_from(const T *Addr, Flags) SYCL_ESIMD_FUNCTION { constexpr unsigned Size = sizeof(T) * N; constexpr unsigned Align = Flags::template alignment; - simd Tmp; + constexpr unsigned BlockSize = OperandSize::OWORD * 8; + constexpr unsigned NumBlocks = Size / BlockSize; + constexpr unsigned RemSize = Size % BlockSize; + if constexpr (Align >= OperandSize::DWORD && Size % OperandSize::OWORD == 0 && - detail::isPowerOf2(Size / OperandSize::OWORD)) { - Tmp = block_load(Addr, Flags{}); + detail::isPowerOf2(RemSize / OperandSize::OWORD)) { + if constexpr (NumBlocks > 0) { + constexpr unsigned BlockN = BlockSize / sizeof(T); + ForHelper::unroll([BlockN, Addr, this](unsigned Block) { + select(Block * BlockN) = + block_load(Addr + (Block * BlockN), Flags{}); + }); + } + if constexpr (RemSize > 0) { + constexpr unsigned RemN = RemSize / sizeof(T); + constexpr unsigned BlockN = BlockSize / sizeof(T); + select(NumBlocks * BlockN) = + block_load(Addr + (NumBlocks * BlockN), Flags{}); + } } else if constexpr (sizeof(T) == 8) { - constexpr unsigned AlignUH = - (N * 4) % Align == 0 ? Align : std::min(Align, 4u); - simd LH(reinterpret_cast(Addr), Flags{}); - simd UH(reinterpret_cast(Addr) + N, - overaligned); - Tmp.template bit_cast_view().template select(0) = LH; - Tmp.template bit_cast_view().template select(N) = UH; - } else if constexpr (N == 1) { - Tmp = *Addr; - } else if constexpr (N == 8 || N == 16 || N == 32) { - simd Offsets(0u, sizeof(T)); - Tmp = gather(Addr, Offsets); + simd BC(reinterpret_cast(Addr), Flags{}); + bit_cast_view() = BC; } else { - constexpr int N1 = N < 8 ? 8 : N < 16 ? 16 : 32; - simd_mask_type Pred(0); - Pred.template select() = 1; - simd Offsets(0u, sizeof(T)); - simd Vals = gather(Addr, Offsets, Pred); - Tmp = Vals.template select(); - } - *this = Tmp.data(); + constexpr unsigned NumChunks = N / ChunkSize; + if constexpr (NumChunks > 0) { + simd Offsets(0u, sizeof(T)); + ForHelper::unroll([Addr, &Offsets, this](unsigned Block) { + select(Block * ChunkSize) = + gather(Addr + (Block * ChunkSize), Offsets); + }); + } + constexpr unsigned RemN = N % ChunkSize; + if constexpr (RemN > 0) { + if constexpr (RemN == 1) { + select<1, 1>(NumChunks * ChunkSize) = Addr[NumChunks * ChunkSize]; + } else if constexpr (RemN == 8 || RemN == 16) { + simd Offsets(0u, sizeof(T)); + select(NumChunks * ChunkSize) = + gather(Addr + (NumChunks * ChunkSize), Offsets); + } else { + constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32; + simd_mask_type Pred(0); + Pred.template select() = 1; + simd Offsets(0u, sizeof(T)); + simd Vals = + gather(Addr + (NumChunks * ChunkSize), Offsets, Pred); + select(NumChunks * ChunkSize) = + Vals.template select(); + } + } + } } template -template +template ESIMD_INLINE EnableIfAccessor simd_obj_impl::copy_from(AccessorT acc, uint32_t offset, @@ -776,69 +803,125 @@ simd_obj_impl::copy_from(AccessorT acc, uint32_t offset, constexpr unsigned Size = sizeof(T) * N; constexpr unsigned Align = Flags::template alignment; - simd Tmp; + constexpr unsigned BlockSize = OperandSize::OWORD * 8; + constexpr unsigned NumBlocks = Size / BlockSize; + constexpr unsigned RemSize = Size % BlockSize; + if constexpr (Align >= OperandSize::DWORD && Size % OperandSize::OWORD == 0 && - detail::isPowerOf2(Size / OperandSize::OWORD)) { - Tmp = block_load(acc, offset, Flags{}); + detail::isPowerOf2(RemSize / OperandSize::OWORD)) { + if constexpr (NumBlocks > 0) { + constexpr unsigned BlockN = BlockSize / sizeof(T); + ForHelper::unroll([BlockN, acc, offset, this](unsigned Block) { + select(Block * BlockN) = + block_load( + acc, offset + (Block * BlockSize), Flags{}); + }); + } + if constexpr (RemSize > 0) { + constexpr unsigned RemN = RemSize / sizeof(T); + constexpr unsigned BlockN = BlockSize / sizeof(T); + select(NumBlocks * BlockN) = + block_load( + acc, offset + (NumBlocks * BlockSize), Flags{}); + } } else if constexpr (sizeof(T) == 8) { - constexpr unsigned AlignUH = - (N * 4) % Align == 0 ? Align : std::min(Align, 4u); - simd LH(acc, offset, Flags{}); - simd UH(acc, offset + N * 4, overaligned); - Tmp.template bit_cast_view().template select(0) = LH; - Tmp.template bit_cast_view().template select(N) = UH; - } else if constexpr (N == 1 || N == 8 || N == 16 || N == 32) { - simd Offsets(0u, sizeof(T)); - Tmp = gather(acc, Offsets, offset); + simd BC(acc, offset, Flags{}); + bit_cast_view() = BC; } else { - constexpr int N1 = N < 8 ? 8 : N < 16 ? 16 : 32; - simd_mask_type Pred(0); - Pred.template select() = 1; - simd Offsets(0u, sizeof(T)); - simd Vals = gather(acc, Offsets, offset, Pred); - Tmp = Vals.template select(); - } - *this = Tmp.data(); + constexpr unsigned NumChunks = N / ChunkSize; + if constexpr (NumChunks > 0) { + simd Offsets(0u, sizeof(T)); + ForHelper::unroll( + [acc, offset, &Offsets, this](unsigned Block) { + select(Block * ChunkSize) = + gather( + acc, Offsets, offset + (Block * ChunkSize * sizeof(T))); + }); + } + constexpr unsigned RemN = N % ChunkSize; + if constexpr (RemN > 0) { + if constexpr (RemN == 1 || RemN == 8 || RemN == 16) { + simd Offsets(0u, sizeof(T)); + select(NumChunks * ChunkSize) = gather( + acc, Offsets, offset + (NumChunks * ChunkSize * sizeof(T))); + } else { + constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32; + simd_mask_type Pred(0); + Pred.template select() = 1; + simd Offsets(0u, sizeof(T)); + simd Vals = gather( + acc, Offsets, offset + (NumChunks * ChunkSize * sizeof(T)), Pred); + select(NumChunks * ChunkSize) = + Vals.template select(); + } + } + } } template -template +template void simd_obj_impl::copy_to(T *addr, Flags) const SYCL_ESIMD_FUNCTION { constexpr unsigned Size = sizeof(T) * N; constexpr unsigned Align = Flags::template alignment; + constexpr unsigned BlockSize = OperandSize::OWORD * 8; + constexpr unsigned NumBlocks = Size / BlockSize; + constexpr unsigned RemSize = Size % BlockSize; + + simd Tmp = data(); if constexpr (Align >= OperandSize::OWORD && Size % OperandSize::OWORD == 0 && - detail::isPowerOf2(Size / OperandSize::OWORD)) { - block_store(addr, cast_this_to_derived()); + detail::isPowerOf2(RemSize / OperandSize::OWORD)) { + if constexpr (NumBlocks > 0) { + constexpr unsigned BlockN = BlockSize / sizeof(T); + ForHelper::unroll([BlockN, addr, &Tmp](unsigned Block) { + block_store(addr + (Block * BlockN), + Tmp.template select(Block * BlockN)); + }); + } + if constexpr (RemSize > 0) { + constexpr unsigned RemN = RemSize / sizeof(T); + constexpr unsigned BlockN = BlockSize / sizeof(T); + block_store(addr + (NumBlocks * BlockN), + Tmp.template select(NumBlocks * BlockN)); + } } else if constexpr (sizeof(T) == 8) { - constexpr unsigned AlignUH = - (N * 4) % Align == 0 ? Align : std::min(Align, 4u); - simd Tmp = data(); - simd LH = - Tmp.template bit_cast_view().template select(0); - simd UH = - Tmp.template bit_cast_view().template select(N); - LH.copy_to(reinterpret_cast(addr), Flags{}); - UH.copy_to(reinterpret_cast(addr) + N, overaligned); - } else if constexpr (N == 1) { - *addr = data()[0]; - } else if constexpr (N == 8 || N == 16 || N == 32) { - simd offsets(0u, sizeof(T)); - scatter(addr, offsets, cast_this_to_derived().data()); + simd BC = Tmp.template bit_cast_view(); + BC.copy_to(reinterpret_cast(addr), Flags{}); } else { - constexpr int N1 = N < 8 ? 8 : N < 16 ? 16 : 32; - simd_mask_type pred(0); - pred.template select() = 1; - simd vals(0); - vals.template select() = cast_this_to_derived().data(); - simd offsets(0u, sizeof(T)); - scatter(addr, offsets, vals, pred); + constexpr unsigned NumChunks = N / ChunkSize; + if constexpr (NumChunks > 0) { + simd Offsets(0u, sizeof(T)); + ForHelper::unroll([addr, &Offsets, &Tmp](unsigned Block) { + scatter( + addr + (Block * ChunkSize), Offsets, + Tmp.template select(Block * ChunkSize)); + }); + } + constexpr unsigned RemN = N % ChunkSize; + if constexpr (RemN > 0) { + if constexpr (RemN == 1) { + addr[NumChunks * ChunkSize] = Tmp[NumChunks * ChunkSize]; + } else if constexpr (RemN == 8 || RemN == 16) { + simd Offsets(0u, sizeof(T)); + scatter(addr + (NumChunks * ChunkSize), Offsets, + Tmp.template select(NumChunks * ChunkSize)); + } else { + constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32; + simd_mask_type Pred(0); + Pred.template select() = 1; + simd Vals(0); + Vals.template select() = + Tmp.template select(NumChunks * ChunkSize); + simd Offsets(0u, sizeof(T)); + scatter(addr + (NumChunks * ChunkSize), Offsets, Vals, Pred); + } + } } } template -template +template ESIMD_INLINE EnableIfAccessor simd_obj_impl::copy_to(AccessorT acc, uint32_t offset, @@ -846,31 +929,62 @@ simd_obj_impl::copy_to(AccessorT acc, uint32_t offset, constexpr unsigned Size = sizeof(T) * N; constexpr unsigned Align = Flags::template alignment; + constexpr unsigned BlockSize = OperandSize::OWORD * 8; + constexpr unsigned NumBlocks = Size / BlockSize; + constexpr unsigned RemSize = Size % BlockSize; + + simd Tmp = data(); if constexpr (Align >= OperandSize::OWORD && Size % OperandSize::OWORD == 0 && - detail::isPowerOf2(Size / OperandSize::OWORD)) { - block_store(acc, offset, cast_this_to_derived()); + detail::isPowerOf2(RemSize / OperandSize::OWORD)) { + if constexpr (NumBlocks > 0) { + constexpr unsigned BlockN = BlockSize / sizeof(T); + ForHelper::unroll([BlockN, acc, offset, &Tmp](unsigned Block) { + block_store( + acc, offset + (Block * BlockSize), + Tmp.template select(Block * BlockN)); + }); + } + if constexpr (RemSize > 0) { + constexpr unsigned RemN = RemSize / sizeof(T); + constexpr unsigned BlockN = BlockSize / sizeof(T); + block_store( + acc, offset + (NumBlocks * BlockSize), + Tmp.template select(NumBlocks * BlockN)); + } } else if constexpr (sizeof(T) == 8) { - constexpr unsigned AlignUH = - (N * 4) % Align == 0 ? Align : std::min(Align, 4u); - simd Tmp = data(); - simd LH = - Tmp.template bit_cast_view().template select(0); - simd UH = - Tmp.template bit_cast_view().template select(N); - LH.copy_to(acc, offset, Flags{}); - UH.copy_to(acc, offset + N * 4, overaligned); - } else if constexpr (N == 1 || N == 8 || N == 16 || N == 32) { - simd offsets(0u, sizeof(T)); - scatter(acc, offsets, cast_this_to_derived().data(), - offset); + simd BC = Tmp.template bit_cast_view(); + BC.copy_to(acc, offset, Flags{}); } else { - constexpr int N1 = N < 8 ? 8 : N < 16 ? 16 : 32; - simd_mask_type pred(0); - pred.template select() = 1; - simd vals(0); - vals.template select() = cast_this_to_derived().data(); - simd offsets(0u, sizeof(T)); - scatter(acc, offsets, vals, offset, pred); + constexpr unsigned NumChunks = N / ChunkSize; + if constexpr (NumChunks > 0) { + simd Offsets(0u, sizeof(T)); + ForHelper::unroll([acc, offset, &Offsets, + &Tmp](unsigned Block) { + scatter( + acc, Offsets, Tmp.template select(Block * ChunkSize), + offset + (Block * ChunkSize * sizeof(T))); + }); + } + constexpr unsigned RemN = N % ChunkSize; + if constexpr (RemN > 0) { + if constexpr (RemN == 1 || RemN == 8 || RemN == 16) { + simd Offsets(0u, sizeof(T)); + scatter( + acc, Offsets, Tmp.template select(NumChunks * ChunkSize), + offset + (NumChunks * ChunkSize * sizeof(T))); + } else { + constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32; + simd_mask_type Pred(0); + Pred.template select() = 1; + simd Vals(0); + Vals.template select() = + Tmp.template select(NumChunks * ChunkSize); + simd Offsets(0u, sizeof(T)); + scatter(acc, Offsets, Vals, + offset + (NumChunks * ChunkSize * sizeof(T)), + Pred); + } + } } } } // namespace detail diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/detail/util.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/detail/util.hpp index bf47b8868128e..12d25e1dbd9af 100755 --- a/sycl/include/sycl/ext/intel/experimental/esimd/detail/util.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/detail/util.hpp @@ -151,6 +151,21 @@ template <> struct word_type { using type = short; }; template <> struct word_type { using type = ushort; }; template <> struct word_type { using type = ushort; }; +// Utility for compile time loop unrolling. +template class ForHelper { + template static inline void repeat(Action A) { + if constexpr (I < N) + A(I); + if constexpr (I + 1 < N) + repeat(A); + } + +public: + template static inline void unroll(Action A) { + ForHelper::template repeat<0, Action>(A); + } +}; + } // namespace detail } // namespace esimd diff --git a/sycl/test/esimd/simd_copy_align_flags.cpp b/sycl/test/esimd/simd_copy_align_flags.cpp index adc76681bf930..123d1ab0c0957 100644 --- a/sycl/test/esimd/simd_copy_align_flags.cpp +++ b/sycl/test/esimd/simd_copy_align_flags.cpp @@ -38,10 +38,15 @@ template simd test_simd_constructor_vector_aligned(T*); \ template simd test_simd_constructor_overaligned(T*); TEST_SVM_CONSTRUCTOR(char, 2) +TEST_SVM_CONSTRUCTOR(char, 52) TEST_SVM_CONSTRUCTOR(short, 5) +TEST_SVM_CONSTRUCTOR(short, 55) TEST_SVM_CONSTRUCTOR(int, 7) +TEST_SVM_CONSTRUCTOR(int, 57) TEST_SVM_CONSTRUCTOR(float, 14) +TEST_SVM_CONSTRUCTOR(float, 54) TEST_SVM_CONSTRUCTOR(double, 16) +TEST_SVM_CONSTRUCTOR(double, 56) #undef TEST_SVM_CONSTRUCTOR @@ -74,10 +79,15 @@ template simd test_simd_constructor_vector_aligned(accessor test_simd_constructor_overaligned(accessor &); TEST_ACC_CONSTRUCTOR(char, 2) +TEST_ACC_CONSTRUCTOR(char, 52) TEST_ACC_CONSTRUCTOR(short, 5) +TEST_ACC_CONSTRUCTOR(short, 55) TEST_ACC_CONSTRUCTOR(int, 7) +TEST_ACC_CONSTRUCTOR(int, 57) TEST_ACC_CONSTRUCTOR(float, 14) +TEST_ACC_CONSTRUCTOR(float, 54) TEST_ACC_CONSTRUCTOR(double, 16) +TEST_ACC_CONSTRUCTOR(double, 56) #undef TEST_ACC_CONSTRUCTOR @@ -114,10 +124,15 @@ template void test_simd_copy_vector_aligned(T*, T*); \ template void test_simd_copy_overaligned(T*, T*); TEST_SVM_COPY(char, 2) +TEST_SVM_COPY(char, 52) TEST_SVM_COPY(short, 5) +TEST_SVM_COPY(short, 55) TEST_SVM_COPY(int, 7) +TEST_SVM_COPY(int, 57) TEST_SVM_COPY(float, 14) +TEST_SVM_COPY(float, 54) TEST_SVM_COPY(double, 16) +TEST_SVM_COPY(double, 56) #undef TEST_SVM_COPY @@ -154,10 +169,15 @@ template void test_simd_copy_vector_aligned(accessor(accessor &, accessor &); TEST_ACC_COPY(char, 2) +TEST_ACC_COPY(char, 52) TEST_ACC_COPY(short, 5) +TEST_ACC_COPY(short, 55) TEST_ACC_COPY(int, 7) +TEST_ACC_COPY(int, 57) TEST_ACC_COPY(float, 14) +TEST_ACC_COPY(float, 54) TEST_ACC_COPY(double, 16) +TEST_ACC_COPY(double, 56) #undef TEST_ACC_COPY