diff --git a/src/cpu/cpu_engine.hpp b/src/cpu/cpu_engine.hpp index fc5b80794f7..f2f290f8021 100644 --- a/src/cpu/cpu_engine.hpp +++ b/src/cpu/cpu_engine.hpp @@ -47,6 +47,7 @@ #define CPU_INSTANCE_RV64GCV(...) DNNL_RV64GCV_ONLY(CPU_INSTANCE(__VA_ARGS__)) #define CPU_INSTANCE_RV64GCV_ZVFH(...) \ DNNL_RV64GCV_ZVFH_ONLY(CPU_INSTANCE(__VA_ARGS__)) +#define CPU_INSTANCE_PPC64(...) DNNL_PPC64_ONLY(CPU_INSTANCE(__VA_ARGS__)) namespace dnnl { namespace impl { diff --git a/src/cpu/gemm/gemm.cpp b/src/cpu/gemm/gemm.cpp index 6a70ed0dbd8..012a433889d 100644 --- a/src/cpu/gemm/gemm.cpp +++ b/src/cpu/gemm/gemm.cpp @@ -43,7 +43,7 @@ using namespace dnnl::impl::cpu::x64; #elif DNNL_PPC64 -#include "cpu/ppc64/ppc64_gemm_driver.hpp" +#include "cpu/ppc64/gemm/gemm_driver.hpp" using namespace dnnl::impl::cpu::ppc64; #elif DNNL_S390X #include "cpu/s390x/gemm.h" @@ -219,11 +219,10 @@ dnnl_status_t gemm_s8u8s32(const char *transa, const char *transb, } #elif DNNL_PPC64 #ifdef __MMA__ - int ATflag = (*transa == 'T') || (*transa == 't'); - int BTflag = (*transb == 'T') || (*transb == 't'); - return cblas_gemm_s8x8s32_ppc64(ATflag, BTflag, offsetc, *M, *N, *K, *alpha, - A, *LDA, ao, B, *LDB, bo, C, *beta, *LDC, co, 0); + status = gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao, B, + LDB, bo, beta, C, LDC, co, false); + if (status != status::unimplemented) return status; #endif #elif DNNL_S390X #if defined(__VX__) @@ -269,18 +268,9 @@ dnnl_status_t gemm_s8s8s32(const char *transa, const char *transb, #if DNNL_PPC64 #ifdef __MMA__ - int ATflag = (*transa == 'T') || (*transa == 't'); - int BTflag = (*transb == 'T') || (*transb == 't'); - - // Note please that the coercion of "B" and "bo" from int8_t to uint8_t is - // accompanied by the last parameter being set to "1" instead of "0", as - // in the analogous call in the previous routine above. - // This last parameter flags the fact of the coercion, so the called routine - // can process "B" and "bo" appropriately. 
- - return cblas_gemm_s8x8s32_ppc64(ATflag, BTflag, offsetc, *M, *N, *K, *alpha, - A, *LDA, ao, (const uint8_t *)B, *LDB, (const uint8_t *)bo, C, - *beta, *LDC, co, 1); + status = gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao, B, + LDB, bo, beta, C, LDC, co, false); + if (status != status::unimplemented) return status; #endif #elif DNNL_S390X #if defined(__VX__) diff --git a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp index f8afda2f34c..94627ac3d61 100644 --- a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp +++ b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp @@ -272,7 +272,8 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { const gemm_based::params_t ¶ms = pd()->params(); const bool use_single_gemm_call = pd()->has_runtime_dims_or_strides() ? helper.use_single_gemm_call_optimization(po) - : params.use_single_gemm_call_optimization_; + : ((platform::is_ppc64() && ndims == 2) + || params.use_single_gemm_call_optimization_); bool dst_is_acc = params.dst_is_acc_; int32_t *acc = dst_is_acc ? reinterpret_cast(dst) @@ -297,6 +298,7 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { == (1 << (ndims - 1)); std::atomic st(status::success); + if (!use_single_gemm_call) { const int src_mask = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims); diff --git a/src/cpu/platform.hpp b/src/cpu/platform.hpp index 2cc142b7a80..43b7f11f17c 100644 --- a/src/cpu/platform.hpp +++ b/src/cpu/platform.hpp @@ -81,7 +81,7 @@ // Helper macros: expand the parameters only on the corresponding architecture. // Equivalent to: #if DNNL_$ARCH ... #endif #define DNNL_X64_ONLY(...) Z_CONDITIONAL_DO(DNNL_X64, __VA_ARGS__) -#define DNNL_PPC64_ONLY(...) Z_CONDITIONAL_DO(DNNL_PPC64_ONLY, __VA_ARGS__) +#define DNNL_PPC64_ONLY(...) Z_CONDITIONAL_DO(DNNL_PPC64, __VA_ARGS__) #define DNNL_S390X_ONLY(...) Z_CONDITIONAL_DO(DNNL_S390X_ONLY, __VA_ARGS__) #define DNNL_AARCH64_ONLY(...) 
Z_CONDITIONAL_DO(DNNL_AARCH64, __VA_ARGS__) @@ -189,6 +189,15 @@ constexpr int get_cache_line_size() { int get_vector_register_size(); +// Helper to avoid #ifdefs for DNNL_PPC64 +static constexpr bool is_ppc64() { +#if DNNL_PPC64 + return true; +#else + return false; +#endif +} + size_t get_timestamp(); } // namespace platform diff --git a/src/cpu/ppc64/gemm/gemm_driver.cpp b/src/cpu/ppc64/gemm/gemm_driver.cpp new file mode 100644 index 00000000000..63d6791866b --- /dev/null +++ b/src/cpu/ppc64/gemm/gemm_driver.cpp @@ -0,0 +1,1595 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifdef __MMA__ + +#include +#if defined(_MSC_VER) +#include +#endif + +#include +#include + +#include "common/dnnl_traits.hpp" +#include "common/nstl.hpp" +#include "common/utils.hpp" +#include "oneapi/dnnl/dnnl_types.h" + +#include "cpu/platform.hpp" + +#include "cpu/gemm/f32/gemm_utils_f32.hpp" +#include "cpu/gemm/gemm.hpp" + +#include "cpu/ppc64/gemm/gemm_driver.hpp" +#include "cpu/ppc64/gemm/gemm_info.hpp" +#include "cpu/ppc64/gemm/gemm_utils.hpp" +#include "cpu/ppc64/ppc64_gemm_s8x8s32.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace ppc64 { + +#define MAX_STACK_SZ 8192 + +template +struct alignas(64) gemm_per_thread_t { + volatile int32_t result; + volatile int32_t compute_done; + int32_t thr_k_stride; + int32_t nthr_k; + dim_t ldc_local; + dim_t ldc_global; + c_type *c_local; + c_type *volatile c_global; + gemm_slice_t slice; +}; + +template +int get_vector_length() { + int v_bytes = 16; + return v_bytes / sizeof(T); +} + +template +static inline void round_to_nearest(c_type *rounded_val, double fp_val) { + if (fp_val >= 0.) 
{ + fp_val += 0.5; + if (fp_val > INT32_MAX) { fp_val = INT32_MAX; } + } else { + fp_val -= 0.5; + if (fp_val < INT32_MIN) { fp_val = INT32_MIN; } + } + *rounded_val = (c_type)fp_val; +} + +template +static void sum_matrices(dim_t m, dim_t n, mat_t *__restrict dst, dim_t ld_dst, + mat_t *__restrict src, dim_t ld_src) { + + for (dim_t j = 0; j < n; j++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < m; i++) + dst[i + j * ld_dst] += src[i + j * ld_src]; + } +} + +template +static void sum_k_blocks( + int ithr, gemm_per_thread_t *thread_arg, bool wait) { + auto m = thread_arg[ithr].slice.m; + auto n = thread_arg[ithr].slice.n; + auto ithr_k = thread_arg[ithr].slice.ithr_k; + auto nthr_k = thread_arg[ithr].nthr_k; + auto stride = thread_arg[ithr].thr_k_stride; + dim_t n0, nn; + + partition_1d(ithr_k, nthr_k, n, n0, nn); + + auto get_thread_arg = [&](int thr_k) -> gemm_per_thread_t & { + return thread_arg[ithr + (thr_k - ithr_k) * stride]; + }; + + auto wait_thread = [&](int thr_k) { + if (wait) { + auto &tk_arg = get_thread_arg(thr_k); + while (!tk_arg.compute_done) {} + } + }; + + auto add_thread_results = [&](int thr_k) { + auto &tk_arg = get_thread_arg(thr_k); + sum_matrices(m, nn, tk_arg.c_global + n0 * tk_arg.ldc_global, + tk_arg.ldc_global, tk_arg.c_local + n0 * tk_arg.ldc_local, + tk_arg.ldc_local); + }; + + // First accumulate this thread's results while they are in cache. + if (ithr_k > 0) { + wait_thread(0); + add_thread_results(ithr_k); + } + + // Then accumulate the others. 
+ for (int thr_k = 1; thr_k < nthr_k; thr_k++) { + if (thr_k != ithr_k) { + wait_thread(thr_k); + add_thread_results(thr_k); + } + } +} + +template +static inline void add_results(const dim_t m, const dim_t n, const float alpha, + const float beta, const c_type *c_partial_sum, const dim_t ldcp, + c_type *c_data, const dim_t ldc, const c_type *co, + offset_type offsetc) { + constexpr bool is_int8 = data_traits_t::data_type == data_type::s32; + for (dim_t j = 0; j < n; ++j) { + for (dim_t i = 0; i < m; ++i) { + c_type ctemp = c_partial_sum[i + j * ldcp]; + if (alpha == 1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = ctemp; + } else { + if (is_int8) { + double c_float + = (double)beta * (double)c_data[i + j * ldc]; + c_float += (double)ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } else { + c_data[i + j * ldc] *= beta; + c_data[i + j * ldc] += ctemp; + } + } + } else if (alpha == -1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = -ctemp; + } else { + if (is_int8) { + double c_float + = (double)beta * (double)c_data[i + j * ldc]; + c_float -= (double)ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } else { + c_data[i + j * ldc] *= beta; + c_data[i + j * ldc] -= ctemp; + } + } + } else { + if (beta == 0.0f) { + if (is_int8) { + double c_float = alpha * (double)ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } else { + c_data[i + j * ldc] = alpha * ctemp; + } + } else { + if (is_int8) { + double c_float = alpha * (double)ctemp + + beta * (double)c_data[i + j * ldc]; + round_to_nearest(&c_data[i + j * ldc], c_float); + } else { + c_data[i + j * ldc] *= beta; + c_data[i + j * ldc] += alpha * ctemp; + } + } + } + if (offsetc == offset_type::fixed) { + c_data[i + j * ldc] += co[0]; + } else if (offsetc == offset_type::row) { + c_data[i + j * ldc] += co[j]; + } else if (offsetc == offset_type::column) { + c_data[i + j * ldc] += co[i]; + } + } + } +} + +template +static inline dim_t get_k_padd( + int ithr, dim_t k, const 
gemm_info_t *arg) { + if (arg->a_packed) { + dim_t block_m, block_k; + arg->a_packed->get_blocking(ithr, block_m, block_k); + return block_k; + } else if (arg->b_packed) { + dim_t block_n, block_k; + arg->b_packed->get_blocking(ithr, block_k, block_n); + return block_k; + } else { + dim_t k_padd = 0; + + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(dim_t(128), k_padd); + } else if (k < 2 * arg->bk) + k_padd = utils::rnd_up((k + 1) / 2, arg->uk); + else + k_padd = arg->bk; + + return k_padd; + } +} + +template +static inline dim_t get_m_padd( + int ithr, dim_t m, const gemm_info_t *arg) { + if (arg->a_packed) { + dim_t block_m, block_k; + arg->a_packed->get_blocking(ithr, block_m, block_k); + return block_m; + } else + return utils::rnd_up( + nstl::min(nstl::max(m, arg->um), arg->bm), arg->um); +} + +template +static inline dim_t get_m_padd_parallel_a(int ithr, dim_t m, + const gemm_info_t *arg, int nthrs) { + auto m_padd = get_m_padd(ithr, m, arg); + + if (!arg->a_packed) { + constexpr auto multiplier = 10; + + m_padd *= nstl::min(nthrs, multiplier); + if (m_padd > m) m_padd = utils::rnd_up(m, arg->um); + } + + return m_padd; +} + +template +static inline dim_t get_n_padd(int ithr, dim_t n, dim_t k, + const gemm_info_t *arg) { + if (arg->b_packed) { + dim_t block_n, block_k; + arg->b_packed->get_blocking(ithr, block_k, block_n); + return block_n; + } else { + auto bn = (k < arg->blocking_small_k) ? 
arg->bn_small_k : arg->bn; + return utils::rnd_up(nstl::min(nstl::max(n, arg->un), bn), arg->un); + } +} + +static inline void *align(void *ptr, size_t alignment) { + return (void *)utils::rnd_up((uintptr_t)ptr, alignment); +} + +template +void gemm_kernel(dim_t m, dim_t n, const dim_t k, const float alpha, + const a_type *a, const uint8_t *b, float beta, c_type *c, + const dim_t ldc, const c_type *a_row_sum, const c_type *b_col_sum, + c_type *row_offset_ws, c_type *col_offset_ws, const c_type *co, + offset_type offsetc, const gemm_info_t *arg) { + + bool col_req = false; + bool row_req = false; + + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + + dim_t m_stk = col_offset_ws ? 1 : m; + dim_t n_stk = row_offset_ws ? 1 : n; + +#if 1 + std::vector col_offset_stk_vec(m_stk); + std::vector row_offset_stk_vec(n_stk); + c_type *col_offset_stk = col_offset_stk_vec.data(); + c_type *row_offset_stk = row_offset_stk_vec.data(); +#else + c_type *col_offset_stk = nullptr; + if (!col_offset_ws) + col_offset_stk = (c_type *)alloca(sizeof *col_offset_stk * m_stk); + + c_type *row_offset_stk = nullptr; + if (!row_offset_ws) + row_offset_stk = (c_type *)alloca(sizeof *row_offset_stk * n_stk); +#endif + + // Use the heap if already allocated and stack otherwise. + c_type *col_offset = col_offset_ws ? col_offset_ws : col_offset_stk; + c_type *row_offset = row_offset_ws ? row_offset_ws : row_offset_stk; + + if (is_int8) { + c_type ao = arg->ao; + c_type bo = arg->bo; + c_type co_0 = offsetc == offset_type::none ? 
0 : co[0]; + + if (bo != 0 || offsetc == offset_type::column) col_req = true; + if (ao != 0 || offsetc == offset_type::row) row_req = true; + + // It needs one of column or row offsets, but it doesn't need both + if ((ao != 0 && bo != 0) + || (offsetc == offset_type::fixed && co_0 != 0)) { + if (!col_req && !row_req) { + if (m <= n) { + col_req = true; + } else { + row_req = true; + } + } + } + + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] = 0; + + if (offsetc == offset_type::column) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co[i]; + } + + if (bo != 0 && a_row_sum) { + for (dim_t i = 0; i < m; i++) + col_offset[i] -= bo * a_row_sum[i]; + } + } + + if (row_req) { + for (dim_t i = 0; i < n; i++) + row_offset[i] = 0; + + if (offsetc == offset_type::row) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co[i]; + } + + if (ao != 0 && b_col_sum) { + for (dim_t i = 0; i < n; i++) + row_offset[i] -= ao * b_col_sum[i]; + } + } + + if (offsetc == offset_type::fixed && co_0 != 0) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co_0; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co_0; + } + } + + if (ao != 0 && bo != 0) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += (c_type)k * ao * bo; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += (c_type)k * ao * bo; + } + } + } + + /* Column and row offsets are ignored by non-integer compute kernels. + * Scaling is done only for bfloat16 kernels. + */ + if (m > 0 && n > 0) + gemm_kernel_8bit(m, n, k, alpha, const_cast(a), + const_cast(b), c, beta, ldc); + // Adding the row & col sums. 
+ for (dim_t j = 0; j < n; j++) { + for (dim_t i = 0; i < m; i++) { + if (row_req) c[i + j * ldc] += row_offset[j]; + if (col_req) c[i + j * ldc] += col_offset[i]; + } + } +} + +template +static dnnl_status_t gemm_kernel_driver(int ithr, dim_t m, dim_t n, dim_t k, + const a_type *a, const b_type *b, float beta, c_type *c, dim_t ldc, + offset_type offsetc, const c_type *co, + const gemm_info_t *arg) { + if (m <= 0 || n <= 0) return dnnl_success; + + dim_t lda = arg->lda; + dim_t ldb = arg->ldb; + + float alpha = arg->alpha; + + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + + const std::shared_ptr &a_packed = arg->a_packed; + const std::shared_ptr &b_packed = arg->b_packed; + + // Scaling C matrix. + if (!is_int8 && beta != 1.0f && beta != 0.0f) { beta = 1.0f; } + + // Quick exit for C = beta * C + if (!is_int8 && alpha == 0.0f) { return dnnl_success; } + + // Get block sizes. + dim_t k_padd = get_k_padd(ithr, k, arg); + dim_t m_padd = get_m_padd(ithr, m, arg); + dim_t n_padd = get_n_padd(ithr, n, k, arg); + + // Padding for temporary buffer for C + dim_t ldc_buf = gemm_utils::get_ld_padd(m_padd); + + dim_t strideAm = (arg->transa == no_trans) ? 1 : lda; + dim_t strideAn = (arg->transa != no_trans) ? 1 : lda; + dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb; + dim_t strideBn = (arg->transb != no_trans) ? 
1 : ldb; + + size_t a_buf_nelems = m_padd * k_padd; + size_t b_buf_nelems = k_padd * n_padd; + + size_t a_row_sum_nelems = m_padd; + size_t b_col_sum_nelems = n_padd; + + if (a_packed) a_buf_nelems = a_row_sum_nelems = 0; + if (b_packed) b_buf_nelems = b_col_sum_nelems = 0; + + size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K + + b_buf_nelems * sizeof(*b) + PAGE_4K; + + if (is_int8) { + mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + } + + size_t col_offset_ws_nelems = arg->um; + size_t row_offset_ws_nelems = n_padd; + size_t stk_sz = (col_offset_ws_nelems + row_offset_ws_nelems) * sizeof(*c); + const bool use_stack = is_int8 && stk_sz <= MAX_STACK_SZ; + if (!use_stack) { + mem_size += col_offset_ws_nelems * sizeof(*c) + PAGE_4K; + mem_size += row_offset_ws_nelems * sizeof(*c) + PAGE_4K; + } + + bool need_c_buffer + = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f))); + + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = nullptr; + + if (mem_size > 0) { + mem = (char *)malloc(mem_size, 128); + if (!mem) return dnnl_out_of_memory; + } + + a_type *bufferA = (a_type *)align(mem, PAGE_4K); + void *p_next_buf = bufferA + a_buf_nelems; + + uint8_t *bufferB = (uint8_t *)align(p_next_buf, PAGE_4K); + p_next_buf = bufferB + b_buf_nelems; + + c_type *a_row_sum = nullptr; + c_type *b_col_sum = nullptr; + if (is_int8) { + a_row_sum = (c_type *)align(p_next_buf, PAGE_4K); + p_next_buf = a_row_sum + a_row_sum_nelems; + + b_col_sum = (c_type *)align(p_next_buf, PAGE_4K); + p_next_buf = b_col_sum + b_col_sum_nelems; + } + + c_type *col_offset_ws = nullptr; + c_type *row_offset_ws = nullptr; + if (!use_stack) { + col_offset_ws = (c_type *)align(p_next_buf, PAGE_4K); + p_next_buf = col_offset_ws + col_offset_ws_nelems; + + row_offset_ws = (c_type *)align(p_next_buf, PAGE_4K); + p_next_buf = row_offset_ws + row_offset_ws_nelems; + } 
+ + c_type *bufferC = nullptr; + if (need_c_buffer) bufferC = (c_type *)align(p_next_buf, PAGE_4K); + + int a_block_copied = 0; + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) sizeM = m_padd; + + dim_t sizeK = 0; + dim_t blk_k = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) { + sizeK = k - Bk; + if (sizeK > k_padd) sizeK = k_padd; + + // Scale C blocks by beta only for the first time + auto beta_eff = (Bk == 0) ? beta : 1.0f; + + // Apply C offset when to the last k-block of the partial sum. + auto offsetc_eff = offset_type::none; + if (Bk + sizeK == k) offsetc_eff = offsetc; + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) sizeN = n_padd; + if (b_packed) { + bufferB = b_packed->matrix(ithr, Bk, Bn); + if (is_int8) + b_col_sum = b_packed->col_sums(ithr, blk_k, Bn); + } else { + const b_type *b_block = b + Bk * strideBm + Bn * strideBn; + /* Column sum argument is ignored for non-integer kernels + * and scaling factor is ignored by 8-bit and 16-bit copy + * kernels. 
+ */ + if (arg->transb) { + packB_T8_8bit(sizeK, sizeN, b_block, ldb, + bufferB, arg->b_is_signed); + } else { + packB_N8bit(sizeK, sizeN, b_block, ldb, bufferB, + arg->b_is_signed); + } + // Currently calculating the b_col sum here only to check unit test ase passed or not + if (arg->ao != 0) { + if (arg->transb == 1) { + for (int i = 0; i < sizeN; i++) { + int sum = 0; + for (int j = 0; j < sizeK; j++) { + if (arg->b_is_signed) { + sum += b_block[j * ldb + i] + 128; + } else { + sum += b_block[j * ldb + i]; + } + } + b_col_sum[i] = sum; + } + } else { + for (int i = 0; i < sizeN; i++) { + int sum = 0; + for (int j = 0; j < sizeK; j++) { + if (arg->b_is_signed) { + sum += b_block[i * ldb + j] + 128; + } else { + sum += b_block[i * ldb + j]; + } + } + b_col_sum[i] = sum; + } + } + } + } + + dim_t sizeUM = 0; + for (dim_t Um = 0; Um < sizeM; Um += sizeUM) { + sizeUM = sizeM - Um; + if (sizeUM > arg->um) sizeUM = arg->um; + + /* Use the whole A buffer only if we have multiple B + * blocks for k-dimension, otherwise we are wasting cache + * to store B and C blocks. + */ + dim_t Um_forA = 0; + if (sizeN < n) Um_forA = Um; + + a_type *bufferA_eff = nullptr; + c_type *a_row_sum_eff = nullptr; + if (a_packed) { + Um_forA = Um; + + // TODO Can we simplify this! + dim_t buf_shift = 0; + buf_shift = Um_forA * sizeK; + bufferA_eff = a_packed->matrix(ithr, Bm, Bk) + + buf_shift; + + if (is_int8) + a_row_sum_eff = a_packed->row_sums( + ithr, Bm, blk_k) + + Um_forA; + } else { + // TODO Can we simplify this! + dim_t buf_shift = 0; + buf_shift = Um_forA * ((sizeK + 3) & (~3)); + + bufferA_eff = bufferA + buf_shift; + a_row_sum_eff + = a_row_sum ? 
a_row_sum + Um_forA : nullptr; + + if (!a_block_copied) { + const a_type *a_block + = a + (Bm + Um) * strideAm + Bk * strideAn; + if (arg->transa) { + pack_N16_8bit_V2_lxvp<__vector signed char>( + sizeK, sizeUM, a_block, lda, + bufferA_eff, a_row_sum_eff); + } else { + pack_T16_8bit_V2<__vector signed char>(sizeK, + sizeUM, a_block, lda, bufferA_eff, + a_row_sum_eff); + } + } + } + + c_type *c_block = c + (Bm + Um) + Bn * ldc; + + dim_t co_stride = 0; + if (offsetc_eff == offset_type::row) + co_stride = Bn; + else if (offsetc_eff == offset_type::column) + co_stride = Bm + Um; + + if (need_c_buffer) { + gemm_kernel(sizeUM, sizeN, sizeK, 1.0f, bufferA_eff, + bufferB, 0.0f, bufferC + Um, ldc_buf, + a_row_sum_eff, b_col_sum, row_offset_ws, + col_offset_ws, (c_type *)nullptr, + offset_type::none, arg); + add_results(sizeUM, sizeN, alpha, beta_eff, + bufferC + Um, ldc_buf, c_block, ldc, + co + co_stride, offsetc_eff); + } else { + gemm_kernel(sizeUM, sizeN, sizeK, alpha, bufferA_eff, + bufferB, beta_eff, c_block, ldc, a_row_sum_eff, + b_col_sum, row_offset_ws, col_offset_ws, + co + co_stride, offsetc_eff, arg); + } + } + a_block_copied = 1; + } + a_block_copied = 0; + } + } + free(mem); + return dnnl_success; +} + +template +static dnnl_status_t kernel_driver_parallel_acopiedbcopy(int ithr, dim_t m, + dim_t n, dim_t k, dim_t blk_k, dim_t Bk, const a_type *bufferA, + const b_type *b, float beta, c_type *c, offset_type offsetc, + const c_type *co, const c_type *a_row_sum, + const gemm_info_t *arg) { + + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + + float alpha = arg->alpha; + + const std::shared_ptr &b_packed = arg->b_packed; + + if (m <= 0 || n <= 0) { return dnnl_success; } + + // Padding along N dimension. + dim_t n_padd = get_n_padd(ithr, n, k, arg); + + // Padding for temporary buffer for C + dim_t ldc_buf = gemm_utils::get_ld_padd(m); + + dim_t strideBn = (arg->transb != 0) ? 
1 : ldb; + + size_t b_buf_nelems = k * n_padd; + size_t b_col_sum_nelems = n_padd; + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + if (b_packed) b_buf_nelems = b_col_sum_nelems = 0; + + size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K; + + if (is_int8) { mem_size += b_col_sum_nelems * sizeof(*c) + PAGE_4K; } + + size_t col_offset_ws_nelems = m; + size_t row_offset_ws_nelems = n_padd; + size_t stk_sz = (col_offset_ws_nelems + row_offset_ws_nelems) * sizeof(*c); + const bool use_stack = is_int8 && stk_sz <= MAX_STACK_SZ; + if (!use_stack) { + mem_size += col_offset_ws_nelems * sizeof(*c) + PAGE_4K; + mem_size += row_offset_ws_nelems * sizeof(*c) + PAGE_4K; + } + + bool need_c_buffer + = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f))); + + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = nullptr; + + if (mem_size > 0) { + mem = (char *)malloc(mem_size, 128); + if (!mem) return dnnl_out_of_memory; + } + + // For Int8 kernel packed b type will always be uint8 + // because s8u8 support on Power10 MMA + uint8_t *bufferB = (uint8_t *)align(mem, PAGE_4K); + void *p_next_buf = bufferB + b_buf_nelems; + + c_type *b_col_sum = nullptr; + if (is_int8) { + b_col_sum = (c_type *)align(p_next_buf, PAGE_4K); + p_next_buf = b_col_sum + b_col_sum_nelems; + } + + c_type *col_offset_ws = nullptr; + c_type *row_offset_ws = nullptr; + if (!use_stack) { + col_offset_ws = (c_type *)align(p_next_buf, PAGE_4K); + p_next_buf = col_offset_ws + col_offset_ws_nelems; + + row_offset_ws = (c_type *)align(p_next_buf, PAGE_4K); + p_next_buf = row_offset_ws + row_offset_ws_nelems; + } + + c_type *bufferC = nullptr; + if (need_c_buffer) bufferC = (c_type *)align(p_next_buf, PAGE_4K); + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) sizeN = n_padd; + + if (b_packed) { + bufferB = 
b_packed->matrix(ithr, Bk, Bn); + if (is_int8) + b_col_sum = b_packed->col_sums(ithr, blk_k, Bn); + } else { + const b_type *b_block = b + Bn * strideBn; + + /* Column sum argument is ignored for non-integer kernels and + * scaling factor is ignored by 8-bit and 16-bit copy kernels. + */ + if (arg->transb) { + packB_T8_8bit( + k, sizeN, b_block, ldb, bufferB, arg->b_is_signed); + } else { + packB_N8bit( + k, sizeN, b_block, ldb, bufferB, arg->b_is_signed); + } + if (arg->ao != 0) { + if (arg->transb == 1) { + for (int i = 0; i < sizeN; i++) { + int sum = 0; + for (int j = 0; j < k; j++) { + if (arg->b_is_signed) { + sum += b_block[j * ldb + i] + 128; + } else { + sum += b_block[j * ldb + i]; + } + } + b_col_sum[i] = sum; + } + } else { + for (int i = 0; i < sizeN; i++) { + int sum = 0; + for (int j = 0; j < k; j++) { + if (arg->b_is_signed) { + sum += b_block[i * ldb + j] + 128; + } else { + sum += b_block[i * ldb + j]; + } + } + b_col_sum[i] = sum; + } + } + } + } + + dim_t co_stride = 0; + if (offsetc == offset_type::fixed) { + co_stride = 0; + } else if (offsetc == offset_type::row) { + co_stride = Bn; + } else if (offsetc == offset_type::column) { + co_stride = 0; + } + + c_type *c_block = c + Bn * ldc; + if (need_c_buffer) { + gemm_kernel(m, sizeN, k, 1.0f, bufferA, bufferB, 0.0f, bufferC, + ldc_buf, a_row_sum, b_col_sum, row_offset_ws, col_offset_ws, + (c_type *)nullptr, offset_type::none, arg); + add_results(m, sizeN, alpha, beta, bufferC, ldc_buf, c_block, ldc, + co + co_stride, offsetc); + } else { + gemm_kernel(m, sizeN, k, alpha, bufferA, bufferB, beta, c_block, + ldc, a_row_sum, b_col_sum, row_offset_ws, col_offset_ws, + co + co_stride, offsetc, arg); + } + } + + free(mem); + + return dnnl_success; +} + +template +static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn, + gemm_threading_t &thread_info, + const gemm_info_t *arg) { + + static constexpr dim_t N2D_MAX = 384; + static constexpr dim_t M2D_MIN = 384; + + constexpr bool is_int8 = 
utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + + dim_t m = arg->m; + dim_t n = arg->n; + dim_t k = arg->k; + + thread_info.nthrs_m = 0; + thread_info.nthrs_n = 0; + thread_info.nthrs_k = 0; + thread_info.copy = copy_type::nonshared; + thread_info.partition = partition_type::row_1d; + + // TODO Check if we can use dynamic scheduling for sgemm. + // TODO Check if we should use 3D blocking. + thread_info.nthrs_k = 1; + thread_info.thread_k = k; + + bool condition_2D_bsrc = false; + int scale = nthrs; + condition_2D_bsrc = (256 * m > scale * n) && (scale * m < 256 * n); + + // TODO Check if we should use k-partitioning. + + int condition_1D_copya = false; + if (m >= 1000 && n >= 4000) { + condition_2D_bsrc = false; + condition_1D_copya = true; + } + + // If A or B offset is non-zero, we need to keep 1D_copya to reduce update + // overhead. + // TODO: the reasons seems to be in copy_sum_bx routines. At least, + // after simple optimization of copy_sum_ax for avx512, similar + // restriction on offset B became unnecessary. Revisit. + if (is_int8 && arg->ao != 0 && (arg->bo != 0)) { + condition_2D_bsrc = false; + condition_1D_copya = true; + } + + if (condition_2D_bsrc) { + int nthrs_m = 1; + int nthrs_n = nthrs; + + if (m == 800 && n == 300) { + // TODO: Expand this branch to other problem sizes. 
+ + auto &thread_m = thread_info.thread_m; + auto &thread_n = thread_info.thread_n; + + const dim_t block_m = arg->um * 4; + constexpr dim_t block_n = 64; + constexpr dim_t small_m = 16; + constexpr dim_t small_n = 2; + + std::tie(nthrs_m, nthrs_n) = gemm_utils::calc_nthr_2d(nthrs, m, n, + block_m, block_n, small_m, small_n, thread_m, thread_n); + + thread_info.nthrs_m = nthrs_m; + thread_info.nthrs_n = nthrs_n; + thread_info.partition = partition_type::mnk_3d; + + } else if ((n <= 64 || n >= 256)) { + while (((nthrs_n > 1) && (n / nthrs_n < arg->un) + && (m / nthrs_m >= 2 * arg->um)) + || ((nthrs_n % 2 == 0) + && (n / nthrs > N2D_MAX + || n / nthrs_n <= N2D_MAX / 2) + && (m / nthrs_m >= 2 * M2D_MIN) && (nthrs_m < 4))) { + nthrs_m *= 2; + nthrs_n /= 2; + } + + thread_info.nthrs_m = nthrs_m; + thread_info.nthrs_n = nthrs_n; + thread_info.partition = partition_type::col_major_2d; + } else { + // Use 3D decomposition from pack api without k-partitioning. + set_thread_opts_pack(nthrs, thread_info, arg, false); + } + + } else if (condition_1D_copya && dnnl_thr_syncable()) { + // Use parallel copy A algorithm + thread_info.copy = copy_type::shared_a; + thread_info.partition = partition_type::col_1d; + thread_info.nthrs_m = 1; + thread_info.nthrs_n = nthrs_spawn; // Using all spawned threads. + } else { + auto veclen = get_vector_length(); + + if (m > n && (m >= nthrs * veclen || n < nthrs)) { + if (n <= 20 && is_int8) { + // Use 3D decomposition forcing m-blocking only. 
+ set_thread_opts_pack( + nthrs, thread_info, arg, false, true, false); + } else { + thread_info.partition = partition_type::row_1d; + thread_info.nthrs_m = nthrs; + thread_info.nthrs_n = 1; + } + } else { + thread_info.partition = partition_type::col_1d; + thread_info.nthrs_m = 1; + thread_info.nthrs_n = nthrs; + } + } +} + +template +static inline void set_thread_opts_pack(int nthrs, + gemm_threading_t &thread_info, + const gemm_info_t *arg, + bool do_k_blocking = true, bool do_m_blocking = true, + bool do_n_blocking = true) { + + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + constexpr bool is_bf16 + = data_traits_t::data_type == data_type::bf16; + + bool do_m_blocking_only = do_m_blocking && !do_n_blocking; + + auto m = arg->m, n = arg->n, k = arg->k; + + auto &nthr_m = thread_info.nthrs_m; + auto &nthr_n = thread_info.nthrs_n; + auto &nthr_k = thread_info.nthrs_k; + auto &thread_m = thread_info.thread_m; + auto &thread_n = thread_info.thread_n; + auto &thread_k = thread_info.thread_k; + auto &block_m = thread_info.block_m; + auto &block_n = thread_info.block_n; + auto &block_k = thread_info.block_k; + + constexpr auto MBLK = 64; + constexpr auto NBLK = 64; + auto KBLK = is_int8 ? 3072 : 256; + KBLK = do_m_blocking_only && is_int8 ? 384 : KBLK; + + nthr_m = nthr_n = nthr_k = 1; + thread_info.copy = copy_type::nonshared; + thread_info.partition = partition_type::mnk_3d; + + auto choose_blocking + = [](dim_t size_z, dim_t &thread_z, int &nthr_z, dim_t block_z_init, + dim_t &block_z, dim_t block_align) { + thread_z = utils::div_up(size_z, nthr_z); + auto num_blk = utils::div_up(thread_z, block_z_init); + block_z = utils::div_up(thread_z, num_blk); + block_z = utils::rnd_up(block_z, block_align); + thread_z = num_blk * block_z; + if (thread_z * nthr_z > size_z) + nthr_z = utils::div_up(size_z, thread_z); + }; + + auto choose_m_blocking = [&]() { + auto align = get_vector_length(); + align = do_m_blocking_only ? 
arg->um : align; + choose_blocking(m, thread_m, nthr_m, arg->bm, block_m, align); + }; + auto choose_n_blocking = [&]() { + choose_blocking(n, thread_n, nthr_n, arg->bn, block_n, arg->un); + }; + auto choose_k_blocking = [&]() { + auto align = nstl::max(arg->uk, dim_t(4)); + choose_blocking(k, thread_k, nthr_k, arg->bk, block_k, align); + }; + + // Choose k blocking. + if ((m / MBLK + n / NBLK) < nthrs && do_k_blocking) { + for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++) + if (nthrs % nk == 0) nthr_k = nk; + + // Sacrifice one thread and try again if parallelism is too small in + // n-dimension. + if (nthr_k == 1 && nthrs > 1 && do_m_blocking_only) { + nthrs--; + for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++) + if (nthrs % nk == 0) nthr_k = nk; + } + + // Allow up to 2 threads to be sacrificed for large k >> m, n. + if (nthr_k < 4 && k >= m * 4 && k >= n * 4 && nthrs > 10 && is_bf16) { + for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++) + if (nthrs % nk <= 2) nthr_k = nk; + } + } + + choose_k_blocking(); + + // Choose m/n blocking. + auto min_mblk = arg->um; + min_mblk = do_m_blocking ? min_mblk : m; + min_mblk = do_m_blocking_only ? arg->um : min_mblk; + auto min_nblk = do_n_blocking ? NBLK / 2 : n; + + std::tie(nthr_m, nthr_n) = partition_2d_minblk(m, n, MBLK, NBLK, min_mblk, + min_nblk, arg->um, arg->un, nthrs / nthr_k, + do_m_blocking && do_n_blocking && do_k_blocking); + + auto nthr_m_init = nthr_m, nthr_n_init = nthr_n; + + choose_m_blocking(); + choose_n_blocking(); + + if (is_int8 && do_m_blocking && do_n_blocking) { + // If we lost a thread in one dimension because we padded the blocking + // size, try to rebalance the other dimensions. 
+ if ((nthr_n != nthr_n_init) + && ((nthr_m + 1) * nthr_n * nthr_k <= nthrs)) { + nthr_m++; + choose_m_blocking(); + } + + if ((nthr_m != nthr_m_init) + && (nthr_m * (nthr_n + 1) * nthr_k <= nthrs)) { + nthr_n++; + choose_n_blocking(); + } + } +} + +template +static inline int set_thread_opts(int nthrs, int nthrs_spawn, + gemm_threading_t &thread_info, + const gemm_info_t *arg) { + + thread_info.block_m = thread_info.block_n = thread_info.block_k = -1; + thread_info.thread_m = thread_info.thread_n = thread_info.thread_k = -1; + + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + + if (arg->packing != pack_type::none && (is_int8)) + set_thread_opts_pack(nthrs, thread_info, arg); + else + set_thread_opts_nopack(nthrs, nthrs_spawn, thread_info, arg); + + return thread_info.nthrs_m * thread_info.nthrs_n * thread_info.nthrs_k; +} + +template +static inline std::tuple +decompose_matrices(const gemm_slice_t &slice, + const gemm_info_t *arg) { + + dim_t stride_am = (arg->transa == no_trans) ? 1 : arg->lda; + dim_t stride_ak = (arg->transa != no_trans) ? 1 : arg->lda; + dim_t stride_bn = (arg->transb != no_trans) ? 1 : arg->ldb; + dim_t stride_bk = (arg->transb == no_trans) ? 
1 : arg->ldb; + + auto a = arg->a; + auto b = arg->b; + auto c = arg->c; + if (a) a += slice.off_m * stride_am + slice.off_k * stride_ak; + if (b) b += slice.off_n * stride_bn + slice.off_k * stride_bk; + if (c) c += slice.off_m + slice.off_n * arg->ldc; + + dim_t co_stride; + switch (arg->offsetc) { + case offset_type::row: co_stride = slice.off_n; break; + case offset_type::column: co_stride = slice.off_m; break; + default: co_stride = 0; break; + } + auto co = arg->co; + if (co) co += co_stride; + + return std::make_tuple(a, b, c, co); +} + +template +static dnnl_status_t parallel_a_copy(const int ithr, const int nthrs, + const dim_t m, const dim_t n, const dim_t k, const a_type *a, + const b_type *b, float beta, c_type *c, dim_t ldc, offset_type offsetc, + const c_type *co, const gemm_info_t *arg, + char **p_shared_mem) { + const dim_t lda = arg->lda; + const dim_t ldb = arg->ldb; + const dim_t strideAm = (arg->transa == no_trans) ? 1 : lda; + const dim_t strideAn = (arg->transa != no_trans) ? 1 : lda; + const dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb; + + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + const std::shared_ptr &a_packed = arg->a_packed; + + // Padding along M, K dimensions. + dim_t m_padd = get_m_padd_parallel_a(ithr, m, arg, nthrs); + dim_t k_padd = get_k_padd(ithr, k, arg); + + size_t a_buf_nelems = m_padd * k_padd; + + // Allocate shared memory for A and its row sum buffers in master thread. 
+ char *mem = nullptr; + a_type *bufferA = nullptr; + c_type *a_row_sum = nullptr; + + if (!a_packed) { + if (ithr == 0) { // If thread master + size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K); + + if (is_int8) { + size_t a_row_sum_nelems = m_padd; + mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K; + } + + *p_shared_mem = (char *)malloc(mem_size, 128); + } + + dnnl_thr_barrier(); + + mem = *p_shared_mem; + bufferA = (a_type *)align(mem, PAGE_4K); + + if (is_int8) + a_row_sum = (c_type *)align(bufferA + a_buf_nelems, PAGE_4K); + + if (!mem) return dnnl_out_of_memory; + } + + dnnl_status_t result = dnnl_success; // Return status + + dim_t sizeK = 0; + dim_t blk_k = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) { + sizeK = k - Bk; + if (sizeK > k_padd) sizeK = k_padd; + + // Scale C blocks by beta only for the first term of partial sum. + auto beta_eff = (Bk == 0) ? beta : 1.0f; + + // Apply C offset for the last k-block of the partial sum. + auto offsetc_eff = offset_type::none; + if (Bk + sizeK == k) offsetc_eff = offsetc; + + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) sizeM = m_padd; + + if ((ithr < nthrs) && !a_packed) { + dim_t band = (sizeM + nthrs - 1) / nthrs; + band = utils::rnd_up(band, arg->um); + + dim_t offset = band * ithr; + + // If offset is too large don't use that thread for copying. + if (offset >= sizeM) { + offset = 0; + band = 0; + } + + // Handle the tail of the copy. + if (offset + band > sizeM) { band = sizeM - offset; } + + if (band > 0) { + const a_type *a_block + = a + (Bm + offset) * strideAm + Bk * strideAn; + + dim_t buf_shift = 0; + buf_shift = offset * ((sizeK + 3) & ~3); + + /* Row sum argument is ignored for non-integer kernels and + * scaling factor is ignored by 8-bit and 16-bit copy + * kernels. + */ + c_type *a_row_sum_eff + = a_row_sum ? 
a_row_sum + offset : nullptr; + if (arg->transa) { + pack_N16_8bit_V2_lxvp<__vector signed char>(sizeK, band, + a_block, lda, bufferA + buf_shift, + a_row_sum_eff); + } else { + pack_T16_8bit_V2<__vector signed char>(sizeK, band, + a_block, lda, bufferA + buf_shift, + a_row_sum_eff); + } + } + } + if (!a_packed) + dnnl_thr_barrier(); // Wait for finishing parallel copy. + + const b_type *b_block = b + Bk * strideBm; + c_type *c_block = c + Bm; + + dim_t co_stride = 0; + if (offsetc_eff == offset_type::fixed) { + co_stride = 0; + } else if (offsetc_eff == offset_type::row) { + co_stride = 0; + } else if (offsetc_eff == offset_type::column) { + co_stride = Bm; + } + + auto bufferA_eff + = a_packed ? a_packed->matrix(0, Bm, Bk) : bufferA; + auto a_row_sum_eff = a_packed + ? a_packed->row_sums(0, Bm, blk_k) + : a_row_sum; + + auto this_result = kernel_driver_parallel_acopiedbcopy(ithr, sizeM, + n, sizeK, blk_k, Bk, bufferA_eff, b_block, beta_eff, + c_block, offsetc_eff, co + co_stride, a_row_sum_eff, arg); + + if (this_result != dnnl_success) result = this_result; + + if (!a_packed) + dnnl_thr_barrier(); // Wait for kernel computations to finish. 
+ } + } + // Free memory allocated in master thread + if (ithr == 0 && !a_packed) free(mem); + + return result; +} + +template +static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) { + + const double omp_overhead_small_core = 3.0e+3; + const double omp_intercept_big_core = 4.0e+3; + const double omp_slope_big_core = 5.0e+2; + + auto veclen = get_vector_length(); + const double fp_per_cycle = 2.0 * 2.0 * veclen; + double gemm_cycles = m * n * k / fp_per_cycle; + gemm_cycles *= 8.0; + + int i = *nthrs; + + // Use a different model for omp overheads if nthrs is <= 4 + if (*nthrs <= 4 && omp_overhead_small_core > 0) { + double omp_cycles = omp_overhead_small_core; + if (gemm_cycles < omp_cycles) { + *nthrs = 1; + return; + } else { + while (i > 1) { + if (omp_cycles * i < gemm_cycles * (i - 1)) break; + --i; + } + } + } else { + if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) { + *nthrs = 1; + return; + } + + // adaptive decrement to march fasterĀ· + while (i > 1) { + double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core; + if (omp_cycles * i < gemm_cycles * (i - 1)) break; + + if (i < 10) + i -= 2; + else if (i < 30) + i -= 4; + else + i -= 8; + } + } + + if (i < 1) i = 1; + + *nthrs = i; +} + +template +static dnnl_status_t gemm_threading_driver( + gemm_info_t *arg) { + + auto packing = (arg->packing != pack_type::none); + auto is_a_packed = (arg->transa == packed); + auto is_b_packed = (arg->transb == packed); + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + + if ((arg->m <= 0) || (arg->n <= 0)) return dnnl_success; + + if (is_a_packed && arg->bo != 0) + if (!arg->a_packed->has_row_sums()) return dnnl_invalid_arguments; + + if (is_b_packed && arg->ao != 0) + if (!arg->b_packed->has_col_sums()) return dnnl_invalid_arguments; + + auto nthr_max = dnnl_get_current_num_threads(); + int nthr_goal = nthr_max; + + adjust_thread_count(arg->m, arg->n, arg->k, 
&nthr_goal); + + const gemm_threading_t *force_threading = nullptr; + gemm_threading_t force_k_decomp; + + // Initialize per-thread data. + // Note: to support k blocking with non-packed GEMM, threading must be + // chosen now and force_threading set. + if (!packing) { + // Override choice of thread count if data is pre-packed for a particular + // number of threads. + if (is_a_packed && is_b_packed) + if (arg->a_packed->threading() != arg->b_packed->threading()) + return dnnl_invalid_arguments; + if (is_a_packed) + force_threading = &arg->a_packed->threading(); + else if (is_b_packed) + force_threading = &arg->b_packed->threading(); + else if (arg->n <= 128 && arg->k >= 3072 && is_int8) { + // Try k-partitioning. + set_thread_opts_pack(nthr_goal, force_k_decomp, arg); + + // Decide partition type later if no partitions in k-dimension. + if (force_k_decomp.nthrs_k > 1) force_threading = &force_k_decomp; + } + + if (force_threading) { + nthr_goal = force_threading->nthrs(); + arg->update_blocking(*force_threading); + } + } else { + // Prepare packed data layout. + gemm_pack_storage_t *pack_dst = arg->pack_dst; + bool do_a = (arg->packing == pack_type::pack_a); + + pack_dst->which() = do_a ? matrix_id::a : matrix_id::b; + pack_dst->setup(nthr_goal, do_a && is_int8, !do_a && is_int8); + + auto &thread_info = pack_dst->threading(); + force_threading = &thread_info; + + nthr_goal = set_thread_opts(nthr_goal, nthr_max, thread_info, arg); + arg->update_blocking(thread_info); + + if (thread_info.copy != copy_type::no_copy) { + for (int ithr = 0; ithr < nthr_goal; ithr++) { + if (!pack_dst->is_first_thread_in_slice(ithr)) continue; + + auto slice = thread_info.get_thread_slice( + ithr, arg->m, arg->n, arg->k); + + auto m = slice.m, n = slice.n, k = slice.k; + + auto m_padd = (thread_info.copy == copy_type::shared_a) + ? 
get_m_padd_parallel_a( + ithr, m, arg, thread_info.nthrs()) + : get_m_padd(ithr, m, arg); + auto n_padd = get_n_padd(ithr, n, k, arg); + auto k_padd = get_k_padd(ithr, k, arg); + + do_a ? pack_dst->set_blocking(ithr, m, k, m_padd, k_padd) + : pack_dst->set_blocking(ithr, k, n, k_padd, n_padd); + } + } else { + auto ld = do_a ? gemm_utils::get_ld_padd(arg->m) + : gemm_utils::get_ld_padd(arg->k); + + pack_dst->set_nocopy(0, no_trans, ld, do_a ? arg->k : arg->n); + } + + do_a ? pack_dst->finalize() + : pack_dst->finalize(); + + if (arg->measure_only) return dnnl_success; + } + + // This needs to see whether we need this function or not. + if (nthr_goal == 1) + return gemm_kernel_driver(0, arg->m, arg->n, arg->k, arg->a, arg->b, + arg->beta, arg->c, arg->ldc, arg->offsetc, arg->co, arg); + + bool k_blocking = force_threading && (force_threading->nthrs_k > 1); + bool k_summing = k_blocking && !packing; + + auto *thread_arg = (gemm_per_thread_t *)malloc( + sizeof(gemm_per_thread_t) * nthr_max, PAGE_4K); + + if (!thread_arg) return dnnl_out_of_memory; + + dim_t max_mt = 0, max_nt = 0; + for (int ithr = 0; ithr < nthr_max; ithr++) { + thread_arg[ithr].result = dnnl_success; + thread_arg[ithr].compute_done = false; + thread_arg[ithr].c_local = nullptr; + thread_arg[ithr].c_global = nullptr; + thread_arg[ithr].ldc_global = arg->ldc; + thread_arg[ithr].ldc_local = 0; + + if (force_threading) { + thread_arg[ithr].slice = force_threading->get_thread_slice( + ithr, arg->m, arg->n, arg->k); + thread_arg[ithr].nthr_k = force_threading->nthrs_k; + thread_arg[ithr].thr_k_stride = force_threading->thr_k_stride(); + max_mt = nstl::max(max_mt, thread_arg[ithr].slice.m); + max_nt = nstl::max(max_nt, thread_arg[ithr].slice.n); + } else { + thread_arg[ithr].slice = {0, 0, 0, 0, 0, 0, 0, 0, 0}; + thread_arg[ithr].nthr_k = 1; + thread_arg[ithr].thr_k_stride = 0; + } + } + + // Create temporary C buffers for k blocking if needed. 
+ c_type *c_local_storage = nullptr; + if (k_summing) { + const dim_t BAD_LD_MULT = 256; + dim_t ldc_local = max_mt % BAD_LD_MULT + ? max_mt + : gemm_utils::get_ld_padd(max_mt); + dim_t c_local_stride = ldc_local * max_nt; + c_local_storage = (c_type *)malloc( + sizeof(c_type) * c_local_stride * nthr_goal, PAGE_4K); + + if (!c_local_storage) { + free(thread_arg); + return dnnl_out_of_memory; + } + + for (int ithr = 0; ithr < nthr_goal; ithr++) { + thread_arg[ithr].c_local = c_local_storage + ithr * c_local_stride; + thread_arg[ithr].ldc_local = ldc_local; + } + } + + char *shared_mem = nullptr; + + // Always use the maximum number of threads to avoid OMP overhead that can + // occur due to change thread counts. + int nthr_spawn = dnnl_thr_syncable() ? nthr_max : nthr_goal; + + parallel(nthr_spawn, [&](int ithr, int nthr) { + int nthr_eff = force_threading ? nthr_goal : nstl::min(nthr_goal, nthr); + + if (nthr_eff == 1) { + thread_arg[0].result = gemm_kernel_driver(0, arg->m, arg->n, arg->k, + arg->a, arg->b, arg->beta, arg->c, arg->ldc, arg->offsetc, + arg->co, arg); + } else { + gemm_threading_t thread_info; + + if (force_threading) + thread_info = *force_threading; + else { + nthr_eff = set_thread_opts(nthr_eff, nthr, thread_info, arg); + if (ithr < nthr_eff) + thread_arg[ithr].slice = thread_info.get_thread_slice( + ithr, arg->m, arg->n, arg->k); + } + + for (; ithr < nthr_eff; ithr += nthr) { + // Get submatrices and parameters for this thread's GEMM. 
+ const a_type *a = nullptr; + const b_type *b = nullptr; + c_type *c = nullptr; + const c_type *co = nullptr; + std::tie(a, b, c, co) + = decompose_matrices(thread_arg[ithr].slice, arg); + + auto m = thread_arg[ithr].slice.m; + auto n = thread_arg[ithr].slice.n; + auto k = thread_arg[ithr].slice.k; + thread_arg[ithr].c_global = c; + auto c_eff = c; + auto ldc_eff = arg->ldc; + auto beta_eff = arg->beta; + auto offsetc_eff = arg->offsetc; + + // For all but first k block: substitute local C matrix and + // disable postops. + if (k_summing && thread_arg[ithr].slice.ithr_k > 0) { + c_eff = thread_arg[ithr].c_local; + ldc_eff = thread_arg[ithr].ldc_local; + beta_eff = 0; + offsetc_eff = offset_type::none; + } + + // Dispatch appropriate GEMM driver. + switch (thread_info.copy) { + case copy_type::shared_a: + + thread_arg[ithr].result = parallel_a_copy(ithr, + nthr_eff, m, n, k, a, b, beta_eff, c_eff, + ldc_eff, offsetc_eff, co, arg, &shared_mem); + break; + + default: + case copy_type::nonshared: + thread_arg[ithr].result = gemm_kernel_driver(ithr, m, n, + k, a, b, beta_eff, c_eff, ldc_eff, offsetc_eff, + co, arg); + break; + } + // Sum thread results along k dimension, parallelized in the n + // dimension. To avoid deadlocks, results are summed later if + // not all threads are running concurrently. We can only detect + // if this is safe when using OpenMP. +#if DNNL_THR_SYNC == 1 + if (k_summing && (nthr >= nthr_eff)) { + thread_arg[ithr].compute_done = true; + sum_k_blocks(ithr, thread_arg, true); + } +#endif + } + } + }); + + dnnl_status_t result = dnnl_success; // Initialize to success + for (int ithr = 0; ithr < nthr_max; ithr++) { + if (thread_arg[ithr].result != dnnl_success) { + result = static_cast(thread_arg[ithr].result); + break; + } + } + // Sum thread results along k dimension if this wasn't done earlier. 
+ if (k_summing && !thread_arg[0].compute_done) { + parallel(nthr_goal, [&](int ithr, int nthr) { + for (; ithr < nthr_goal; ithr += nthr) + sum_k_blocks(ithr, thread_arg, false); + }); + } + + if (c_local_storage) dnnl::impl::free(c_local_storage); + dnnl::impl::free(thread_arg); + + return result; +} + +template +dnnl_status_t gemm_driver(const char *transA, const char *transB, + const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const a_type *a, const dim_t *lda, const a_type *oa, + const b_type *b, const dim_t *ldb, const b_type *ob, const float *beta, + c_type *c, const dim_t *ldc, const c_type *oc, const bool force_nocopy, + pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) { + + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + MAYBE_UNUSED(is_int8); + + gemm_info_t args(transA, transB, offsetC, m, n, k, + alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, force_nocopy, + packing, pack_dst, measure_only); + + return gemm_threading_driver(&args); +} + +template // Instantiate gemm_s8s8s32 + dnnl_status_t + gemm_driver(const char *transA, + const char *transB, const char *offsetC, const dim_t *m, + const dim_t *n, const dim_t *k, const float *alpha, + const int8_t *a, const dim_t *lda, const int8_t *oa, + const int8_t *b, const dim_t *ldb, const int8_t *ob, + const float *beta, int32_t *c, const dim_t *ldc, + const int32_t *oc, const bool force_nocopy, pack_type packing, + gemm_pack_storage_t *pack_dst, bool measure_only); + +template // Instantiate gemm_s8u8s32 + dnnl_status_t + gemm_driver(const char *transA, + const char *transB, const char *offsetC, const dim_t *m, + const dim_t *n, const dim_t *k, const float *alpha, + const int8_t *a, const dim_t *lda, const int8_t *oa, + const uint8_t *b, const dim_t *ldb, const uint8_t *ob, + const float *beta, int32_t *c, const dim_t *ldc, + const int32_t *oc, const bool force_nocopy, pack_type packing, + 
gemm_pack_storage_t *pack_dst, bool measure_only); +#undef MAX_STACK_SZ +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl +#endif diff --git a/src/cpu/ppc64/gemm/gemm_driver.hpp b/src/cpu/ppc64/gemm/gemm_driver.hpp new file mode 100644 index 00000000000..777606902bd --- /dev/null +++ b/src/cpu/ppc64/gemm/gemm_driver.hpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_PPC64_GEMM_GEMM_DRIVER_HPP +#define CPU_PPC64_GEMM_GEMM_DRIVER_HPP + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "oneapi/dnnl/dnnl_types.h" + +#include "cpu/ppc64/gemm/gemm_info.hpp" +#include "cpu/ppc64/gemm/gemm_pack_storage.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace ppc64 { + +template +dnnl_status_t gemm_driver(const char *transA, const char *transB, + const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const a_type *a, const dim_t *lda, const a_type *oa, + const b_type *b, const dim_t *ldb, const b_type *ob, const float *beta, + c_type *c, const dim_t *ldc, const c_type *oc, + const bool force_jit_nocopy_gemm, pack_type packing = pack_type::none, + gemm_pack_storage_t *pack_dst = NULL, bool measure_only = false); + +void prep_ref_gemm_s8u8s32_pack( + bool do_a, dim_t rows, dim_t cols, gemm_pack_storage_t *pack_dst); + +dnnl_status_t ref_gemm_s8u8s32_pack(const void *src, dim_t ld_src, dim_t rows, + dim_t cols, int trans, gemm_pack_storage_t *dst_pack); + +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl +#endif diff --git a/src/cpu/ppc64/gemm/gemm_info.cpp b/src/cpu/ppc64/gemm/gemm_info.cpp new file mode 100644 index 00000000000..1678fe44098 --- /dev/null +++ b/src/cpu/ppc64/gemm/gemm_info.cpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "common/dnnl_traits.hpp" +#include "cpu/gemm/gemm.hpp" +#include "cpu/ppc64/gemm/gemm_info.hpp" +#include "oneapi/dnnl/dnnl_types.h" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace ppc64 { + +static inline int decode_trans(char trans) { + switch (trans) { + case 'T': + case 't': return do_trans; + case 'P': + case 'p': return packed; + default: return no_trans; + } +} + +namespace { +template // XXX for float and bfloat +void prepare_bo(int32_t &bo_gemm_info, const b_t *bo_orig) { + UNUSED(bo_orig); + bo_gemm_info = 0; +} +template <> +void prepare_bo(int32_t &bo_gemm_info, const uint8_t *bo_orig) { + bo_gemm_info = bo_orig ? *bo_orig : 0; +} +template <> +void prepare_bo(int32_t &bo_gemm_info, const int8_t *bo_orig) { + int bo_s32 = bo_orig ? 
*bo_orig : 0; + bo_s32 += 128; + bo_gemm_info = bo_s32; +} + +} // namespace + +template +gemm_info_t::gemm_info_t(const char *transA, const char *transB, + const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const a_t *a, const dim_t *lda, const a_t *oa, + const b_t *b, const dim_t *ldb, const b_t *ob, const float *beta, + c_t *c, const dim_t *ldc, const c_t *oc, bool force_nocopy, + pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) { + + this->transa = decode_trans(*transA); + this->transb = decode_trans(*transB); + + this->m = *m; + this->n = *n; + this->k = *k; + + this->a = a; + this->b = b; + this->c = c; + + this->lda = lda ? *lda : 0; + this->ldb = ldb ? *ldb : 0; + this->ldc = ldc ? *ldc : 0; + + this->ao = 0; + this->bo = 0; + this->co = nullptr; + + this->alpha = alpha ? *alpha : 1.0f; + this->beta = beta ? *beta : 1.0f; + + this->offsetc = offset_type::none; + + this->packing = packing; + this->pack_dst = pack_dst; + this->measure_only + = measure_only && pack_dst && (packing != pack_type::none); + + if (this->transa == packed) { + dim_t cols; + + this->a_packed.reset(new gemm_pack_storage_t(a)); + if (this->a_packed->get_nocopy(this->transa, this->lda, cols)) { + this->a = this->a_packed->template matrix(); + this->a_packed = nullptr; + } + } + + if (this->transb == packed) { + dim_t rows; + + this->b_packed.reset(new gemm_pack_storage_t(b)); + if (this->b_packed->get_nocopy(this->transb, this->ldb, rows)) { + this->b = this->b_packed->template matrix(); + this->b_packed = nullptr; + } + } + + constexpr bool is_int8 = utils::one_of( + data_traits_t::data_type, data_type::s8, data_type::u8); + if (is_int8) this->ao = oa ? 
*oa : a_t(0); + prepare_bo(this->bo, ob); + + this->b_is_signed = false; + + if (data_traits_t::data_type == data_type::s8) + this->b_is_signed = true; + + if (offsetC != nullptr) { + char offsetc = *offsetC; + if (offsetc == 'F' || offsetc == 'f') { + this->offsetc = offset_type::fixed; + } else if (offsetc == 'R' || offsetc == 'r') { + this->offsetc = offset_type::row; + } else { // offsetc == 'C' || offsetc == 'c' + this->offsetc = offset_type::column; + } + this->co = oc; + } + + // Blocking of M, N and K + this->um = 16; + this->un = 4; + this->uk = 1; + this->bm = 4096; + this->bn = 128; + this->bk = 128; + this->bk_traditional = 128; + this->blocking_small_k = 64; + this->bn_small_k = 16; +} + +template +void gemm_info_t::update_blocking( + const gemm_threading_t &thread_info) { + + if (thread_info.block_m > 0) this->bm = thread_info.block_m; + if (thread_info.block_n > 0) this->bn = thread_info.block_n; + if (thread_info.block_k > 0) this->bk = thread_info.block_k; +} + +// Instantiate the gemm_info_t templates needed. +template // For gemm_s8u8s32 + struct gemm_info_t; + +template // For gemm_s8s8s32 + struct gemm_info_t; + +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/ppc64/gemm/gemm_info.hpp b/src/cpu/ppc64/gemm/gemm_info.hpp new file mode 100644 index 00000000000..16eec47426d --- /dev/null +++ b/src/cpu/ppc64/gemm/gemm_info.hpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_PPC64_GEMM_GEMM_INFO_HPP +#define CPU_PPC64_GEMM_GEMM_INFO_HPP + +#include +#include + +#include "common/c_types_map.hpp" +#include "cpu/ppc64/gemm/gemm_pack_storage.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace ppc64 { + +enum class pack_type { none, pack_a, pack_b }; + +enum class offset_type { + none, + fixed, + column, + row, +}; + +// Indices for kernel arrays. TODO Is it okay to place this here? +enum { no_sum = 0, do_sum = 1 }; +enum { no_trans = 0, do_trans = 1, packed = 2 }; +enum { no_beta0 = 0, do_beta0 = 1 }; +enum { no_alpha1 = 0, do_alpha1 = 1 }; + +template +struct gemm_info_t { + + // Interface arguments. + int transa, transb; + offset_type offsetc; + dim_t m, n, k; + dim_t lda, ldb, ldc; + const a_t *a; + const b_t *b; + c_t *c; + float alpha, beta; + + bool b_is_signed; + int32_t ao; + int32_t bo; + const c_t *co; + + pack_type packing; + gemm_pack_storage_t *pack_dst; + bool measure_only; + std::shared_ptr a_packed, b_packed; + + // Kernel parameters. 
+ dim_t um, un, uk, bm, bn, bk; + dim_t bn_small_k, bk_traditional, blocking_small_k; + + // Gemv parameters + int swap = false; + gemm_info_t(const char *transA, const char *transB, const char *offsetC, + const dim_t *m, const dim_t *n, const dim_t *k, const float *alpha, + const a_t *a, const dim_t *lda, const a_t *oa, const b_t *b, + const dim_t *ldb, const b_t *ob, const float *beta, c_t *c, + const dim_t *ldc, const c_t *oc, bool force_nocopy, + pack_type packing, gemm_pack_storage_t *pack_dst, + bool measure_only); + + void update_blocking(const gemm_threading_t &thread_info); +}; +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif // CPU_PPC64_GEMM_GEMM_INFO_HPP diff --git a/src/cpu/ppc64/gemm/gemm_pack_storage.hpp b/src/cpu/ppc64/gemm/gemm_pack_storage.hpp new file mode 100644 index 00000000000..67a9ca3c807 --- /dev/null +++ b/src/cpu/ppc64/gemm/gemm_pack_storage.hpp @@ -0,0 +1,395 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_PPC64_GEMM_GEMM_PACK_STORAGE_HPP +#define CPU_PPC64_GEMM_GEMM_PACK_STORAGE_HPP + +#include + +#include "common/dnnl_thread.hpp" +#include "common/utils.hpp" + +#include "cpu/ppc64/gemm/gemm_threading.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace ppc64 { + +enum struct matrix_id { a, b }; + +struct gemm_pack_storage_t { + gemm_threading_t &threading() { return header->threading; } + matrix_id &which() { return header->which; } + bool &has_row_sums() { return header->has_row_sums; } + bool &has_col_sums() { return header->has_col_sums; } + + const gemm_threading_t &threading() const { return header->threading; } + const matrix_id &which() const { return header->which; } + const bool &has_row_sums() const { return header->has_row_sums; } + const bool &has_col_sums() const { return header->has_col_sums; } + + size_t size() const { return header->size; } + void *get() const { return static_cast(base); } + void set(void *data) { + base = static_cast(data); + header = static_cast(data); + } + + bool single_nocopy() const { + return (threading().copy == copy_type::no_copy); + } + + int nthr() const { return single_nocopy() ? 1 : threading().nthrs(); } + + int nslice() const { + return (which() == matrix_id::a) + ? threading().nthrs_m * threading().nthrs_k + : threading().nthrs_n * threading().nthrs_k; + } + + template + gemm_pack_storage_t(data_type *data_, bool header_set_ = true) + : base(nullptr) + , header(nullptr) + , matrix_header(nullptr) + , sums_header(nullptr) + , header_set(header_set_) { + reset((void *)data_); + } + + gemm_pack_storage_t() + : base(nullptr) + , header(nullptr) + , matrix_header(nullptr) + , sums_header(nullptr) + , header_set(true) {} + + std::tuple thread_slice_info(int ithr) const { + assert(ithr < nthr()); + + bool is_a = (which() == matrix_id::a); + auto nthr_inner = is_a ? 
threading().nthrs_m : threading().nthrs_n; + + auto ithr_i = ithr % threading().nthrs_m; + auto ithr_jk = ithr / threading().nthrs_m; + auto ithr_j = ithr_jk % threading().nthrs_n; + auto ithr_k = ithr_jk / threading().nthrs_n; + + auto ithr_inner = is_a ? ithr_i : ithr_j; + auto ithr_outer = ithr_k; + auto ithr_slice = is_a ? ithr_j : ithr_i; + + auto id = ithr_outer * nthr_inner + ithr_inner; + + return std::make_tuple(id, ithr_slice); + } + + int thread_to_slice(int ithr) const { + return std::get<0>(thread_slice_info(ithr)); + } + + bool is_first_thread_in_slice(int ithr) const { + return (std::get<1>(thread_slice_info(ithr)) == 0); + } + + template + data_type *row_sums(int ithr, dim_t r0, dim_t cblock) const { + if (!has_row_sums()) return NULL; + auto id = thread_to_slice(ithr); + return get_block(sums_header->slice[id], r0, cblock); + } + + template + data_type *col_sums(int ithr, dim_t rblock, dim_t c0) const { + if (!has_col_sums()) return NULL; + auto id = thread_to_slice(ithr); + return get_block(sums_header->slice[id], rblock, c0); + } + + template + data_type *matrix(int ithr, dim_t r0, dim_t c0) const { + auto id = thread_to_slice(ithr); + return get_block(matrix_header->slice[id], r0, c0); + } + + template + data_type *matrix(int ithr) const { + assert(!matrix_header->slice[thread_to_slice(ithr)].packed); + return matrix(ithr, 0, 0); + } + + template + data_type *matrix() const { + assert(single_nocopy()); + return matrix(0); + } + + bool get_nocopy(int ithr, int &trans, dim_t &ld, dim_t &td) const { + auto id = thread_to_slice(ithr); + return matrix_header->slice[id].get_nocopy(trans, ld, td); + } + + bool get_nocopy(int &trans, dim_t &ld, dim_t &td) const { + if (!single_nocopy()) return false; + return get_nocopy(0, trans, ld, td); + } + + void get_blocking(int ithr, dim_t &block_r, dim_t &block_c) const { + auto id = thread_to_slice(ithr); + matrix_header->slice[id].get_blocking(block_r, block_c); + } + + void set_blocking( + int ithr, dim_t 
rows, dim_t cols, dim_t block_r, dim_t block_c) { + + auto id = thread_to_slice(ithr); + auto nblk_r = (block_r == 0) ? 0 : utils::div_up(rows, block_r); + auto nblk_c = (block_c == 0) ? 0 : utils::div_up(cols, block_c); + + matrix_header->slice[id].set_blocking(nblk_r, nblk_c, block_r, block_c); + + if (has_row_sums()) + sums_header->slice[id].set_blocking(nblk_r, nblk_c, block_r, 1); + else + sums_header->slice[id].set_blocking(nblk_r, nblk_c, 1, block_c); + } + + void set_nocopy(int ithr, int trans, dim_t ld, dim_t td) { + auto id = thread_to_slice(ithr); + matrix_header->slice[id].set_nocopy(trans, ld, td); + } + + void setup(int max_nthr, bool has_row_sums = false, + bool has_col_sums = false) { + + assert(!(has_row_sums && has_col_sums)); + + auto sz_mh = matrix_header_size(max_nthr); + auto sz_h = header_size(); + + header->has_row_sums = has_row_sums; + header->has_col_sums = has_col_sums; + header->off_matrix = sz_h; + header->off_sums = sz_h + sz_mh; + total_header_size = sz_h + sz_mh * 2; + + header->size = 0; + + header_set = true; + + reset(get()); + + for (int id = 0; id < max_nthr; id++) { + matrix_header->slice[id].set_blocking(0, 0, 0, 0); + sums_header->slice[id].set_blocking(0, 0, 0, 0); + } + } + + template + void finalize() { + assert(total_header_size > 0); + size_t cur_off = total_header_size; + + matrix_header->finalize(cur_off, nslice()); + if (has_row_sums() || has_col_sums()) + sums_header->finalize(cur_off, nslice()); + + header->size = cur_off; + + /* Compute kernels overrun to preload data. 
*/ + header->size += align_data; + } + +protected: + char *base; + + struct header_t { + matrix_id which; + bool has_row_sums; + bool has_col_sums; + size_t off_matrix, off_sums; + size_t size; + gemm_threading_t threading; /* if packed */ + } * header; + + struct slice_header_t { + bool packed; + int trans; + dim_t nblk_r, nblk_c; + dim_t block_r, block_c; + size_t off_data; + + template + size_t block_size() const { + return utils::rnd_up( + block_r * block_c * sizeof(data_type), align_data); + } + + template + size_t block_offset(dim_t r0, dim_t c0, bool col_major) const { + assert((r0 % block_r) == 0); + assert((c0 % block_c) == 0); + + auto rb = r0 / block_r; + auto cb = c0 / block_c; + auto mb = col_major ? rb + cb * nblk_r : cb + rb * nblk_c; + + return block_size() * mb; + } + + template + size_t size() const { + return block_size() * nblk_r * nblk_c; + } + + void set_blocking( + dim_t nblk_r_, dim_t nblk_c_, dim_t block_r_, dim_t block_c_) { + packed = true; + nblk_r = nblk_r_; + nblk_c = nblk_c_; + block_r = block_r_; + block_c = block_c_; + } + + void set_nocopy(int trans_, dim_t ld, dim_t td) { + packed = false; + trans = trans_; + block_r = ld; + block_c = td; + nblk_r = 1; + nblk_c = 1; + } + + void get_blocking(dim_t &block_r_, dim_t &block_c_) const { + block_r_ = block_r; + block_c_ = block_c; + } + + bool get_nocopy(int &trans_, dim_t &ld, dim_t &td) const { + if (!packed) { + trans_ = trans; + ld = block_r; + td = block_c; + } + return !packed; + } + + template + void finalize(size_t &cur_off) { + cur_off = utils::rnd_up(cur_off, align_data); + off_data = cur_off; + cur_off += size(); + } + }; + + struct matrix_header_t { + dim_t ld; /* if not packed */ + slice_header_t slice[1]; /* array of size nthr, if packed */ + + template + void finalize(size_t &cur_off, int nslices) { +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL + // This, I hope, is a temporary workaround... 
+ // The reason for this special case is that in case of threadpool + // threading this function may be called to estimate the amount of + // memory needed when no threading information is actually + // available. Hence, it needs to provide an upper bound. + size_t max_off = cur_off; + for (int id = 0; id < nslices; id++) { + slice[id].finalize(cur_off); + if (id == 0) { + // Assume that slice[0] is the largest one. + size_t slice0_size = cur_off - max_off; + max_off += slice0_size * dnnl_get_max_threads(); + } + } + if (!threadpool_utils::get_active_threadpool() && nslices) + // The std::max is a paranoid check for the case when slice[0] + // is not actually the largest one. Probably a crash will + // happen anyways... + cur_off = std::max(cur_off, max_off); +#else + for (int id = 0; id < nslices; id++) + slice[id].finalize(cur_off); +#endif + } + } * matrix_header, *sums_header; + + size_t total_header_size = 0; + + static constexpr auto align_headers = 0x20; + static constexpr auto align_data = 0x1000; + + static size_t header_size() { + return utils::rnd_up(sizeof(header_t), align_headers); + } + + static size_t matrix_header_size(int max_nthr) { + auto sz = sizeof(matrix_header_t) + + sizeof(slice_header_t) * (max_nthr - 1); + + return utils::rnd_up(sz, align_headers); + } + + template + data_type *get_block( + const slice_header_t &slice, dim_t r0, dim_t c0) const { + return reinterpret_cast(base + slice.off_data + + slice.block_offset(r0, c0, col_major())); + } + + bool col_major() const { return (which() == matrix_id::a); } + + void reset(void *data) { + set(data); + + if (!header_set) return; + + matrix_header = reinterpret_cast( + base + header->off_matrix); + sums_header + = reinterpret_cast(base + header->off_sums); + } + + bool header_set = true; +}; + +struct gemm_pack_storage_shell_t : public gemm_pack_storage_t { + + gemm_pack_storage_shell_t(int max_nthr, bool has_row_sums = false, + bool has_col_sums = false) { + void *ptr = 
malloc(shell_size(max_nthr), 64); + if (ptr) { + set(ptr); + setup(max_nthr, has_row_sums, has_col_sums); + } + } + + ~gemm_pack_storage_shell_t() { free(get()); } + +private: + static size_t shell_size(int max_nthr) { + return header_size() + matrix_header_size(max_nthr) * 2; + } +}; + +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/ppc64/gemm/gemm_partition.hpp b/src/cpu/ppc64/gemm/gemm_partition.hpp new file mode 100644 index 00000000000..be796c97af2 --- /dev/null +++ b/src/cpu/ppc64/gemm/gemm_partition.hpp @@ -0,0 +1,283 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#ifndef CPU_PPC64_GEMM_GEMM_PARTITION_HPP
+#define CPU_PPC64_GEMM_GEMM_PARTITION_HPP
+
+#include 
+#include 
+#include 
+
+#include "common/nstl.hpp"
+#include "common/utils.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace ppc64 {
+
+// Split a 1-D range of size n into nthrs contiguous bands and return the
+// band assigned to thread ithr via t_offset/t_block.
+//
+// All threads except possibly the last receive `band = n / nthrs` elements
+// (bumped by one when the remainder would otherwise exceed band + 1, so the
+// tail thread is not left with a disproportionately large chunk). Threads
+// whose computed offset falls past n get an empty slice (t_block == 0);
+// the last partial band is clipped to not run past n.
+static inline void partition_1d(const int ithr, const int nthrs, const dim_t n,
+        dim_t &t_offset, dim_t &t_block) {
+
+    dim_t band = n / nthrs;
+
+    dim_t tail = n - (nthrs - 1) * band;
+    // If the remainder is more than one element larger than a band, grow the
+    // band so the tail shrinks; recompute the tail with the new band size.
+    if (tail > (band + 1)) band++;
+    tail = n - (nthrs - 1) * band;
+
+    if (ithr < (nthrs - 1))
+        t_block = band;
+    else
+        t_block = tail;
+
+    t_offset = ithr * band;
+
+    if (t_offset >= n) {
+        // This thread has no work: empty slice at offset 0.
+        t_block = 0;
+        t_offset = 0;
+    } else if ((t_offset + t_block) > n) {
+        // Clip the final (partial) band to the end of the range.
+        t_block = n - t_offset;
+    }
+}
+
+// Split an m x n iteration space across an nthrs_m x nthrs_n thread grid.
+// (ithr_i, ithr_j) is this thread's coordinate in the grid; the assigned
+// tile is returned through out_m_disp/out_m_band (rows) and
+// out_n_disp/out_n_band (columns). *nthrs is updated to the number of
+// threads that actually receive work (it may shrink when m or n is small);
+// threads with ithr >= *nthrs get an empty tile.
+static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i,
+        const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m,
+        const dim_t n, dim_t &out_m_disp, dim_t &out_m_band, dim_t &out_n_disp,
+        dim_t &out_n_band) {
+
+    dim_t m_disp = 0, n_disp = 0;
+    dim_t m_band = 0, n_band = 0;
+
+    int m_div = nthrs_m;
+    int n_div = nthrs_n;
+
+    dim_t m_bandt = m / m_div; /* size per thread */
+    dim_t n_bandt = n / n_div; /* size per thread */
+    // The "first group" of threads takes slightly larger bands when the
+    // dimension does not divide evenly; threads after it take base-size bands.
+    int first_m_group = m_div - 1;
+    int first_n_group = n_div - 1;
+    dim_t first_m_val = m_bandt;
+    dim_t first_n_val = n_bandt;
+
+    int mthr_used = m_div;
+    if (m - (m_div - 1) * m_bandt > m_bandt + 1) {
+        if (m - (m_div - 1) * m_bandt > m_div) ++m_bandt;
+
+        first_m_val = m_bandt + 1;
+        mthr_used = (int)(m / first_m_val);
+
+        if (mthr_used * first_m_val < m) ++mthr_used;
+
+        first_m_group = mthr_used - 1;
+    }
+
+    int nthr_used = n_div;
+    if (n - (n_div - 1) * n_bandt > n_bandt + 1) {
+        first_n_val = n_bandt + 1;
+        nthr_used = (int)(n / first_n_val);
+
+        if (nthr_used * first_n_val < n) ++nthr_used;
+
+        first_n_group = nthr_used - 1;
+    }
+
+    // Effective thread count after shrinking either grid dimension.
+    *nthrs = mthr_used * nthr_used;
+
+    if (ithr < *nthrs) {
+        if 
(ithr_i < first_m_group) { + m_band = first_m_val; + m_disp = ithr_i * first_m_val; + } else if (ithr_i <= mthr_used - 2) { + m_band = m_bandt; + m_disp = first_m_group * first_m_val + + (ithr_i - first_m_group) * m_bandt; + } else { + m_disp = first_m_group * first_m_val + + (mthr_used - 1 - first_m_group) * m_bandt; + m_band = nstl::max(dim_t(0), m - m_disp); + } + + if (ithr_j < first_n_group) { + n_band = first_n_val; + n_disp = ithr_j * first_n_val; + } else if (ithr_j <= nthr_used - 2) { + n_band = n_bandt; + n_disp = first_n_group * first_n_val + + (ithr_j - first_n_group) * n_bandt; + } else { + n_disp = first_n_group * first_n_val + + (nthr_used - 1 - first_n_group) * n_bandt; + n_band = nstl::max(dim_t(0), n - n_disp); + } + m_disp = nstl::max(nstl::min(m_disp, m - 1), dim_t(0)); + n_disp = nstl::max(nstl::min(n_disp, n - 1), dim_t(0)); + } + + if (ithr < *nthrs) { + out_m_disp = m_disp; + out_n_disp = n_disp; + out_m_band = m_band; + out_n_band = n_band; + } else { + out_m_disp = 0; + out_n_disp = 0; + out_m_band = 0; + out_n_band = 0; + } + + return; +} + +static inline std::tuple partition_2d_minblk_with_primes(dim_t m, + dim_t n, dim_t block_m, dim_t block_n, dim_t min_m, dim_t min_n, + dim_t um, dim_t un, int nthr, bool use_aspect_ratio) { + + auto part_m = nstl::max(dim_t(1), m / block_m); + auto part_n = nstl::max(dim_t(1), n / block_n); + + // Quick exit if there are enough partitions in one direction + // and there is only 1 partition in the other one + if (part_m == 1 && part_n >= nthr) + return std::make_tuple(1, nstl::min((int)part_n, nthr)); + + if (part_n == 1 && part_m >= nthr) + return std::make_tuple(nstl::min((int)part_m, nthr), 1); + + auto num_parts = part_m * part_n; + + int nthr_ite = nthr; + int nthr_m = 1, nthr_n = 1; + dim_t band_m = m, band_n = n; + + for (auto p : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29}) { + bool finished = false; + + while ((nthr_ite % p) == 0 && !finished) { + nthr_ite /= p; + auto nthr_m_ite = nthr_m * p; + auto 
nthr_n_ite = nthr_n * p; + + auto band_m_ite = band_m / p; + auto band_n_ite = band_n / p; + float band_m_ite_f = static_cast(band_m_ite); + float band_n_ite_f = static_cast(band_n_ite); + + // Try partitioning with block size bm x bn + auto try_partition = [&](dim_t bm, dim_t bn, bool pick_small) { + float ratio_m = band_m_ite_f / static_cast(bm); + float ratio_n = band_n_ite_f / static_cast(bn); + bool do_m = false, do_n = false; + + if (ratio_m < 1. && ratio_n >= 1.) + do_n = true; + else if (ratio_m >= 1. && ratio_n < 1.) + do_m = true; + else if (ratio_m >= 1. && ratio_n >= 1.) { + if (use_aspect_ratio) { + float ratio_goal = static_cast(um) + / static_cast(un); + float try_ratio_m = band_m_ite_f + / static_cast(band_n) + * (1.f / ratio_goal); + float try_ratio_n = static_cast(band_m) + / band_n_ite_f * (1.f / ratio_goal); + if (pick_small) { + // Pick either the smaller or larger ratio as appropriate. + ((ratio_m < ratio_n) ? do_m : do_n) = true; + } else { + // Pick the dimension that will keep as close as possible + // to best ratio between m and n. + ((nstl::abs(try_ratio_m - 1.) + < nstl::abs(try_ratio_n - 1)) + ? do_m + : do_n) + = true; + } + } else { + (((ratio_m < ratio_n) == pick_small) ? do_m : do_n) + = true; + } + } + + if (do_m) { + // Partition m. + nthr_m = nthr_m_ite; + band_m = band_m_ite; + } else if (do_n) { + // Partition n. 
+ nthr_n = nthr_n_ite; + band_n = band_n_ite; + } + + return do_m || do_n; + }; + + // If we will need min based partitioning do it now + if (num_parts < nthr) { + num_parts *= p; + if (try_partition(min_m, min_n, true)) continue; + } + + if (try_partition(block_m, block_n, false)) continue; + if (try_partition(min_m, min_n, true)) continue; + + // Both band_m/n are smaller than min_m/n + // exit the loops, nothing to partition + finished = true; + } + + if (finished) break; + } + + return std::make_tuple(nthr_m, nthr_n); +} + +static inline std::tuple partition_2d_minblk(dim_t m, dim_t n, + dim_t block_m, dim_t block_n, dim_t min_m, dim_t min_n, dim_t um, + dim_t un, int nthr, bool use_aspect_ratio) { + + auto part_m = nstl::max(dim_t(1), m / min_m); + auto part_n = nstl::max(dim_t(1), n / min_n); + + // Quick exit if one of the dimensions is too small to partition. + if (part_m == 1) { + part_n = nstl::max(dim_t(1), utils::div_up(n, min_n)); + return std::make_tuple(1, nstl::min((int)part_n, nthr)); + } + + if (part_n == 1) { + part_m = nstl::max(dim_t(1), utils::div_up(m, min_m)); + return std::make_tuple(nstl::min((int)part_m, nthr), 1); + } + + int nthr_m = 0, nthr_n = 0; + auto nthr_thresh = nstl::min(0.95 * nthr, (double)(part_m * part_n)); + + for (int nthr_new = nthr; nthr_new > nthr / 2; nthr_new--) { + if (nthr_m * nthr_n >= nthr_thresh) break; + std::tie(nthr_m, nthr_n) + = partition_2d_minblk_with_primes(m, n, block_m, block_n, min_m, + min_n, um, un, nthr_new, use_aspect_ratio); + } + + return std::make_tuple(nthr_m, nthr_n); +} + +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/ppc64/gemm/gemm_threading.hpp b/src/cpu/ppc64/gemm/gemm_threading.hpp new file mode 100644 index 00000000000..f698cd5dc9b --- /dev/null +++ b/src/cpu/ppc64/gemm/gemm_threading.hpp @@ -0,0 +1,124 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation 
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_PPC64_GEMM_GEMM_THREADING_HPP
+#define CPU_PPC64_GEMM_GEMM_THREADING_HPP
+
+#include 
+
+#include "common/c_types_map.hpp"
+
+#include "cpu/ppc64/gemm/gemm_partition.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace ppc64 {
+
+// How the GEMM iteration space is partitioned across threads:
+//   row_1d       - split rows (m) only
+//   col_1d       - split columns (n) only
+//   col_major_2d - 2-D m x n grid, threads numbered column-major
+//   mnk_3d       - 3-D split over m, n and k
+enum class partition_type { row_1d, col_1d, col_major_2d, mnk_3d };
+
+// Copy strategy for packed buffers: per-thread copies, a shared copy of A,
+// or no copy at all (kernels read the user matrices directly).
+enum class copy_type {
+    nonshared,
+    shared_a,
+    no_copy,
+};
+
+// One thread's piece of the GEMM problem: offsets and extents in each of
+// the m/n/k dimensions plus the thread's coordinates in the thread grid.
+struct gemm_slice_t {
+    dim_t off_m, off_n, off_k;
+    dim_t m, n, k;
+    int ithr_m, ithr_n, ithr_k;
+};
+
+// Describes the threading configuration for a GEMM call: the thread grid,
+// blocking sizes, partitioning scheme and copy strategy. Members are left
+// uninitialized by the default constructor; callers fill them in.
+struct gemm_threading_t {
+    gemm_threading_t() {};
+
+    int nthrs_m, nthrs_n, nthrs_k;
+    dim_t block_m, block_n, block_k; // Blocking sizes (-1 = default)
+    dim_t thread_m, thread_n, thread_k; // Thread matrix sizes (-1 = default)
+    partition_type partition;
+    copy_type copy;
+
+    // Total number of threads in the grid.
+    int nthrs() const { return nthrs_m * nthrs_n * nthrs_k; }
+
+    // Equality compares the grid shape, partitioning and copy strategy only;
+    // blocking/thread-matrix sizes are intentionally not compared.
+    friend bool operator==(
+            const gemm_threading_t &t1, const gemm_threading_t &t2) {
+        return (t1.nthrs_m == t2.nthrs_m && t1.nthrs_n == t2.nthrs_n
+                && t1.nthrs_k == t2.nthrs_k && t1.partition == t2.partition
+                && t1.copy == t2.copy);
+    }
+
+    friend bool operator!=(
+            const gemm_threading_t &t1, const gemm_threading_t &t2) {
+        return !(t1 == t2);
+    }
+
+    // Compute the slice of an m x n x k GEMM owned by thread ithr according
+    // to `partition`. Threads outside the effective grid receive empty/zero
+    // extents (see partition_1d/partition_2d).
+    gemm_slice_t get_thread_slice(int ithr, dim_t m, dim_t n, dim_t k) const {
+
+        dim_t off_m = 0, off_n = 0, off_k = 0;
+        dim_t size_m = m, size_n = n, size_k = 
k; + int ithr_m = 0, ithr_n = 0, ithr_k = 0; + + switch (partition) { + case partition_type::row_1d: + ithr_m = ithr; + partition_1d(ithr, nthrs(), m, off_m, size_m); + break; + + case partition_type::col_1d: + ithr_n = ithr; + partition_1d(ithr, nthrs(), n, off_n, size_n); + break; + + case partition_type::col_major_2d: { + int nthr_eff = nthrs(); + ithr_m = ithr % nthrs_m; + ithr_n = ithr / nthrs_m; + + partition_2d(ithr, &nthr_eff, ithr_m, ithr_n, nthrs_m, nthrs_n, + m, n, off_m, size_m, off_n, size_n); + break; + } + + case partition_type::mnk_3d: { + assert(thread_m > 0 && thread_n > 0 && thread_k > 0); + ithr_m = ithr % nthrs_m; + ithr_n = (ithr / nthrs_m) % nthrs_n; + ithr_k = (ithr / nthrs_m) / nthrs_n; + + off_m = ithr_m * thread_m; + off_n = ithr_n * thread_n; + off_k = ithr_k * thread_k; + + size_m = nstl::min(thread_m, m - off_m); + size_n = nstl::min(thread_n, n - off_n); + size_k = nstl::min(thread_k, k - off_k); + break; + } + } + + return {off_m, off_n, off_k, size_m, size_n, size_k, ithr_m, ithr_n, + ithr_k}; + } + + int thr_k_stride() const { return nthrs_m * nthrs_n; } +}; + +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/ppc64/gemm/gemm_utils.hpp b/src/cpu/ppc64/gemm/gemm_utils.hpp new file mode 100644 index 00000000000..77ed0c644b9 --- /dev/null +++ b/src/cpu/ppc64/gemm/gemm_utils.hpp @@ -0,0 +1,223 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_PPC64_GEMM_GEMM_UTILS_HPP +#define CPU_PPC64_GEMM_GEMM_UTILS_HPP + +#include + +#include "common/dnnl_thread.hpp" +#include "common/dnnl_traits.hpp" +#include "common/utils.hpp" + +#include "cpu/ppc64/gemm/gemm_pack_storage.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace ppc64 { +namespace gemm_utils { + +static inline std::tuple calc_nthr_2d(int nthrs, dim_t m, dim_t n, + dim_t block_m, dim_t block_n, dim_t small_m, dim_t small_n, + dim_t &thread_m, dim_t &thread_n) { + + int nthr_m = static_cast(utils::div_up(m, block_m)); + int nthr_n = static_cast(utils::div_up(n, block_n)); + + if (nthr_m < 1) nthr_m = 1; + if (nthr_n < 1) nthr_n = 1; + + float ratio_float = static_cast(nthr_m) / static_cast(nthr_n); + + int ratio = 0; + if (nthr_m > nthr_n) + ratio = (int)ratio_float; + else + ratio = (int)(1. 
/ ratio_float); + + // scale down nthr_m and nthr_n if they are too large + while (nthr_m * nthr_n > 4 * nthrs) { + nthr_m /= 2; + nthr_n /= 2; + } + + if (nthr_m < 1) nthr_m = 1; + if (nthr_n < 1) nthr_n = 1; + + // Simple partition reduction + int counter = 0; + while (nthr_m * nthr_n > nthrs) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m--; + else { + nthr_n--; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n--; + else { + nthr_m--; + counter = -1; + } + } + counter++; + } + + // Simple partition increment + counter = 0; + while (nthr_m * nthr_n < 0.95 * nthrs) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m++; + else { + nthr_n++; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n++; + else { + nthr_m++; + counter = -1; + } + } + counter++; + } + + // if nothing works out, then this should work + if ((nthr_m * nthr_n > nthrs)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthrs); + if (nthr_m > utils::div_up(m, small_m)) + nthr_m = static_cast(utils::div_up(m, small_m)); + nthr_n = nthrs / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthrs)) { + nthr_m--; + nthr_n = nthrs / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthrs); + if (nthr_n > utils::div_up(n, small_n)) + nthr_n = static_cast(utils::div_up(n, small_n)); + nthr_m = nthrs / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthrs)) { + nthr_n--; + nthr_m = nthrs / nthr_n; + } + } + } + + thread_m = utils::div_up(m, nthr_m) + small_m - 1; + thread_n = utils::div_up(n, nthr_n) + small_n - 1; + thread_m -= thread_m % small_m; + thread_n -= thread_n % small_n; + + if (thread_m * nthr_m > m) + nthr_m = static_cast(utils::div_up(m, thread_m)); + if (thread_n * nthr_n > n) + nthr_n = static_cast(utils::div_up(n, thread_n)); + + return std::make_tuple(nthr_m, nthr_n); +} + +template +static inline dim_t get_ld_padd(const dim_t x) { + return x != 1 ? 
utils::rnd_up(x, 2048 / sizeof(T)) + (64 / sizeof(T)) : 1; +} + +template +void prep_gemm_pack(bool do_a, int is_trans, dim_t nrows, dim_t ncols, + gemm_pack_storage_t *pack_dst) { + + auto ld = !is_trans ? get_ld_padd(nrows) : get_ld_padd(ncols); + auto td = !is_trans ? ncols : nrows; + + // TODO Do we need to use only one thread? + pack_dst->which() = do_a ? matrix_id::a : matrix_id::b; + pack_dst->setup(1); + pack_dst->threading().copy = copy_type::no_copy; + pack_dst->threading().nthrs_m = 1; + pack_dst->threading().nthrs_n = 1; + pack_dst->threading().nthrs_k = 1; + pack_dst->set_nocopy(0, is_trans, ld, td); + pack_dst->finalize(); +} + +template +dnnl_status_t pack_no_copy(const T *src, dim_t ld_src, dim_t nrows, dim_t ncols, + int trans_src, float alpha, gemm_pack_storage_t *dst_pack) { + + auto dst = dst_pack->matrix(0); + int trans_dst; + dim_t nrows_dst, ncols_dst; + dim_t ld_dst, td_dst; + + //constexpr bool is_f32 = data_traits_t::data_type == data_type::f32; + + if (!dst_pack->get_nocopy(0, trans_dst, ld_dst, td_dst)) + return dnnl_invalid_arguments; + + if (!trans_dst) { + nrows_dst = nrows; + ncols_dst = ncols; + } else { + nrows_dst = ncols; + ncols_dst = nrows; + } + + if (trans_src == trans_dst) { + parallel_nd(ncols_dst, [=](dim_t j) { + auto src_col = src + j * ld_src; + auto dst_col = dst + j * ld_dst; + + PRAGMA_OMP_SIMD() + for (dim_t i = 0; i < nrows_dst; i++) + //if (is_f32) + // dst_col[i] = alpha * src_col[i]; + //else + dst_col[i] = src_col[i]; + }); + } else { + // Naive code for now. 
+ parallel_nd(ncols_dst, [=](dim_t j) { + auto src_col = src + j; + auto dst_col = dst + j * ld_dst; + + PRAGMA_OMP_SIMD() + for (dim_t i = 0; i < nrows_dst; i++) + //if (is_f32) + // dst_col[i] = alpha * src_col[i * ld_src]; + // else + dst_col[i] = src_col[i * ld_src]; + }); + } + + return dnnl_success; +} + +} // namespace gemm_utils +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif // CPU_PPC64_GEMM_GEMM_UTILS_HPP diff --git a/src/cpu/ppc64/ppc64_gemm_reorder.cpp b/src/cpu/ppc64/ppc64_gemm_reorder.cpp new file mode 100644 index 00000000000..ef2eff654d9 --- /dev/null +++ b/src/cpu/ppc64/ppc64_gemm_reorder.cpp @@ -0,0 +1,324 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "cpu/ppc64/ppc64_gemm_reorder.hpp" +#include "cpu/reorder/simple_reorder.hpp" + +#include +#include +#include +#include // For thread sleep + +namespace dnnl { +namespace impl { +namespace cpu { +namespace ppc64 { + +using namespace dnnl::impl::cpu::q10n; + +typedef __vector signed long long vec_i64 __attribute__((aligned(8))); +typedef __vector short vec_i16 __attribute__((aligned(2))); +typedef __vector unsigned char vec_ut; +typedef __vector signed char vec_t; +typedef __vector signed short vec_short_t; +typedef __vector signed int vec_int_t; +typedef __vector float vec_float_t; + +status_t ppc64_matrixA_reorder_t::pd_t::init( + engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { + using namespace status; + + using namespace format_tag; + + status_t status = cpu_reorder_pd_t::init(engine, src_engine, dst_engine); + if (status != success) return status; + + const memory_desc_wrapper id(src_md_), od(dst_md_); + + const int ndims = id.ndims(); + + const auto type_i = id.data_type(); + const auto type_o = od.data_type(); + + const auto in_strides = id.strides(); + const auto out_strides = od.strides(); + + const bool is_row_major = ((in_strides[0] == out_strides[0]) + && (in_strides[1] == out_strides[1]) + && (out_strides[1] == 1)) + ? 
true + : false; + const bool dt_ok = true && utils::one_of(type_i, data_type::f32) + && utils::one_of(type_o, data_type::u8, data_type::s8); + const bool args_ok = dt_ok && ndims == 2 && is_row_major; + + if (!args_ok) return invalid_arguments; + init_scratchpad(); + return status::success; +} + +status_t ppc64_matrixA_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, + const memory_desc_t *dst_md) { + auto _pd = make_unique_pd( + attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); + + if (_pd == nullptr) return status::out_of_memory; + CHECK(_pd->init(engine, src_engine, dst_engine)); + CHECK(_pd->init_scratchpad_md()); + return safe_ptr_assign(*reorder_pd, _pd.release()); +} + +typedef __vector unsigned int VecUInt; + +template +void kernel(InputType *inp, OutputType *out, int N, const float SrcScale, + const float DstScale, const int SrcZeroPoint, const int DstZeroPoint, + const float beta) { + //OutputType ZeroPoint) { + + constexpr int32_t MinimumValue = std::numeric_limits::min(); + constexpr int32_t MaximumValue = std::numeric_limits::max(); + + __vector float SrcScaleVector = vec_splats(SrcScale); + __vector float DstScaleVector = vec_splats(DstScale); + + __vector float MinimumValueVector = vec_splats(float(MinimumValue)); + __vector float MaximumValueVector = vec_splats(float(MaximumValue)); + __vector float SrcZeroPointVector = vec_splats(float(SrcZeroPoint)); + __vector float DstZeroPointVector = vec_splats(float(DstZeroPoint)); + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, inp); + auto FloatVector1 = vec_xl(0, inp + 4); + auto FloatVector2 = vec_xl(0, inp + 8); + auto FloatVector3 = vec_xl(0, inp + 12); + + FloatVector0 = vec_sub(FloatVector0, SrcZeroPointVector); + FloatVector0 = vec_mul(FloatVector0, SrcScaleVector); + FloatVector1 = vec_sub(FloatVector1, SrcZeroPointVector); + FloatVector1 = 
vec_mul(FloatVector1, SrcScaleVector); + FloatVector2 = vec_sub(FloatVector2, SrcZeroPointVector); + FloatVector2 = vec_mul(FloatVector2, SrcScaleVector); + FloatVector3 = vec_sub(FloatVector3, SrcZeroPointVector); + FloatVector3 = vec_mul(FloatVector3, SrcScaleVector); + + if (beta) { + FloatVector0[0] += beta * (float)out[0]; + FloatVector0[1] += beta * (float)out[1]; + FloatVector0[2] += beta * (float)out[2]; + FloatVector0[3] += beta * (float)out[3]; + + FloatVector1[0] += beta * (float)out[4]; + FloatVector1[1] += beta * (float)out[5]; + FloatVector1[2] += beta * (float)out[6]; + FloatVector1[3] += beta * (float)out[7]; + + FloatVector2[0] += beta * (float)out[8]; + FloatVector2[1] += beta * (float)out[9]; + FloatVector2[2] += beta * (float)out[10]; + FloatVector2[3] += beta * (float)out[11]; + + FloatVector3[0] += beta * (float)out[12]; + FloatVector3[1] += beta * (float)out[13]; + FloatVector3[2] += beta * (float)out[14]; + FloatVector3[3] += beta * (float)out[15]; + } + FloatVector0 = vec_mul(FloatVector0, DstScaleVector); + FloatVector1 = vec_mul(FloatVector1, DstScaleVector); + FloatVector2 = vec_mul(FloatVector2, DstScaleVector); + FloatVector3 = vec_mul(FloatVector3, DstScaleVector); + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 = vec_add(FloatVector0, DstZeroPointVector); + FloatVector1 = vec_add(FloatVector1, DstZeroPointVector); + FloatVector2 = vec_add(FloatVector2, DstZeroPointVector); + FloatVector3 = vec_add(FloatVector3, DstZeroPointVector); + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + 
FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + VecUInt IntegerVector0 = vec_ctu(FloatVector0, 0); + VecUInt IntegerVector1 = vec_ctu(FloatVector1, 0); + VecUInt IntegerVector2 = vec_ctu(FloatVector2, 0); + VecUInt IntegerVector3 = vec_ctu(FloatVector3, 0); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + auto CharVector = vec_pack(ShortVector0, ShortVector1); + + vec_xst(CharVector, 0, (uint8_t *)out); + out += 16; + inp += 16; + N -= 16; + } +#ifdef __MMA__ + while (N >= 4) { + auto FloatVector = vec_xl(0, inp); + FloatVector = vec_sub(FloatVector, SrcZeroPointVector); + FloatVector = vec_mul(FloatVector, SrcScaleVector); + + if (beta) { + FloatVector[0] += beta * (float)out[0]; + FloatVector[1] += beta * (float)out[1]; + FloatVector[2] += beta * (float)out[2]; + FloatVector[3] += beta * (float)out[3]; + } + FloatVector = vec_mul(FloatVector, DstScaleVector); + FloatVector = vec_round(FloatVector); + FloatVector = vec_add(FloatVector, DstZeroPointVector); + + FloatVector = vec_max(FloatVector, MinimumValueVector); + FloatVector = vec_min(FloatVector, MaximumValueVector); + auto IntegerVector = vec_ctu(FloatVector, 0); + + auto ShortVector = vec_pack(IntegerVector, vec_splats((uint32_t)0)); + auto CharVector = vec_pack(ShortVector, vec_splats((uint16_t)0)); + + vec_xst_len(CharVector, (uint8_t *)out, N); + + out += 4; + inp += 4; + N -= 4; + } + + if (N > 0) { + auto FloatVector = vec_xl_len(const_cast(inp), 4 * N); + FloatVector = vec_sub(FloatVector, SrcZeroPointVector); + FloatVector = vec_mul(FloatVector, SrcScaleVector); + if (beta) { + if (N == 1) { FloatVector[0] += beta * (float)out[0]; } + if (N == 2) { + FloatVector[0] += beta * (float)out[0]; + FloatVector[1] += beta * (float)out[1]; + } + if (N == 3) { + FloatVector[0] += beta * (float)out[0]; + FloatVector[1] += beta * (float)out[1]; + 
FloatVector[2] += beta * (float)out[2]; + } + } + FloatVector = vec_mul(FloatVector, DstScaleVector); + FloatVector = vec_round(FloatVector); + FloatVector = vec_add(FloatVector, DstZeroPointVector); + + FloatVector = vec_max(FloatVector, MinimumValueVector); + FloatVector = vec_min(FloatVector, MaximumValueVector); + auto IntegerVector = vec_ctu(FloatVector, 0); + + auto ShortVector = vec_pack(IntegerVector, vec_splats((uint32_t)0)); + auto CharVector = vec_pack(ShortVector, vec_splats((uint16_t)0)); + vec_xst_len(CharVector, (uint8_t *)out, N); + } +#else + // For Other than Power10 below code will run + // Remaining 1-15 elements + while (N > 0) { + float val = (*inp - SrcZeroPoint) * SrcScale; + if (beta) val += beta * *out; + val = val * DstScale + DstZeroPoint; + val = std::fmin( + std::fmax(val, float(MinimumValue)), float(MaximumValue)); + *out = uint8_t(std::nearbyint(val)); + inp++; + out++; + N -= 1; + } +#endif +} + +status_t ppc64_matrixA_reorder_t::execute_body(const exec_ctx_t &ctx) const { + using namespace utils; + + const auto input = CTX_IN_MEM(const float *, DNNL_ARG_FROM); + auto output = CTX_OUT_MEM(unsigned char *, DNNL_ARG_TO); + const auto &scratchpad = ctx.get_scratchpad_grantor(); + MAYBE_UNUSED(scratchpad); + const auto input_d = ctx.memory_mdw(DNNL_ARG_FROM, pd()->src_md()); + + DEFINE_ARG_SCALES_BUFFER_ATTR(pd()->attr(), src_scales, DNNL_ARG_FROM); + DEFINE_ARG_SCALES_BUFFER_ATTR(pd()->attr(), dst_scales_, DNNL_ARG_TO); + + int src_scales_mask, dst_scales_mask; + CHECK(get_scales_mask(pd()->attr(), &src_scales_mask, &dst_scales_mask)); + + int scales_mask = std::max(src_scales_mask, dst_scales_mask); + MAYBE_UNUSED(scales_mask); + + dim_t D_start, D_mask, D_rest; + pd()->get_D_values(input_d, scales_mask, &D_start, &D_mask, &D_rest); + + const float *dst_scales = pd()->precompute_scales( + scratchpad, pd()->attr(), D_mask, dst_scales_); + + const int32_t *src_zero_points = CTX_IN_MEM( + const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | 
DNNL_ARG_FROM); + int src_zp = src_zero_points ? src_zero_points[0] : 0; + + const int32_t *dst_zero_points = CTX_IN_MEM( + const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_TO); + int dst_zp = dst_zero_points ? dst_zero_points[0] : 0; + + const float alpha = src_scales[0] * dst_scales[0]; + MAYBE_UNUSED(alpha); + const float beta = pd()->beta(); + + const auto &dims = input_d.dims(); + const auto in_strides = input_d.blocking_desc().strides; + const auto M = dims[0]; + const auto K = dims[1]; + + // Calculate block sizes + dim_t M_b = 16; + dim_t K_b = 64; + K_b = std::min(K_b, K); + + const dim_t num_M_blocks = (M + M_b - 1) / M_b; + const dim_t num_K_blocks = (K + K_b - 1) / K_b; + + parallel_nd(num_M_blocks, num_K_blocks, [&](dim_t mb, dim_t kb) { + dim_t M_start = mb * M_b; + dim_t M_end = nstl::min(M_start + M_b, M); + dim_t K_start = kb * K_b; + dim_t K_end = nstl::min(K_start + K_b, K); + // Iterate over the block + for (dim_t i = M_start; i < M_end; ++i) { + kernel( + input + i * in_strides[0] + K_start, + output + i * in_strides[0] + K_start, K_end - K_start, + src_scales[0], dst_scales[0], src_zp, dst_zp, beta); + } + }); + + return status::success; +} + +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/ppc64/ppc64_gemm_reorder.hpp b/src/cpu/ppc64/ppc64_gemm_reorder.hpp new file mode 100644 index 00000000000..81518df0c68 --- /dev/null +++ b/src/cpu/ppc64/ppc64_gemm_reorder.hpp @@ -0,0 +1,82 @@ +/******************************************************************************* +* Copyright 2022 IBM Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_PPC64_PPC64_GEMM_REORDER_HPP +#define CPU_PPC64_PPC64_GEMM_REORDER_HPP + +#include "common/bfloat16.hpp" +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/math_utils.hpp" +#include "common/memory_tracking.hpp" +#include "common/primitive.hpp" +#include "common/primitive_attr.hpp" +#include "common/tag_traits.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" +#include "common/verbose.hpp" + +#include "cpu/cpu_primitive.hpp" +#include "cpu/reorder/cpu_reorder_pd.hpp" + +#include "cpu/simple_q10n.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace ppc64 { + +using namespace format_tag; + +using bd = block_dim_t; +using ib = inner_blk_t; + +struct ppc64_matrixA_reorder_t : public primitive_t { + + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("ppc64_matrixA_reorder_t", ppc64_matrixA_reorder_t); + + status_t init( + engine_t *engine, engine_t *src_engine, engine_t *dst_engine); + + private: + static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, + const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, + const memory_desc_t *dst_md); + + void init_scratchpad() {} + friend dnnl::impl::impl_list_item_t; + }; + ppc64_matrixA_reorder_t(const pd_t *apd) : primitive_t(apd) {} + + status_t init(engine_t *engine) override { return status::success; } + +private: + status_t 
execute_body(const exec_ctx_t &ctx) const; + status_t execute(const exec_ctx_t &ctx) const override { + return execute_body(ctx); + } + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } +}; + +} // namespace ppc64 +} // namespace cpu +} // namespace impl +} // namespace dnnl +#endif diff --git a/src/cpu/ppc64/ppc64_gemm_s8x8s32.hpp b/src/cpu/ppc64/ppc64_gemm_s8x8s32.hpp index a0756862a17..93376146034 100644 --- a/src/cpu/ppc64/ppc64_gemm_s8x8s32.hpp +++ b/src/cpu/ppc64/ppc64_gemm_s8x8s32.hpp @@ -23,14 +23,13 @@ namespace dnnl { namespace impl { -uint64_t mker; - typedef __vector signed long long vec_i64 __attribute__((aligned(8))); typedef __vector short vec_i16 __attribute__((aligned(2))); typedef __vector unsigned char vec_t; typedef __vector signed char vec_st; +typedef __vector_pair vecp_t; -int pack_N16_16bit(dim_t k, dim_t m, short *a, dim_t lda, short *ap) { +inline int pack_N16_16bit(dim_t k, dim_t m, short *a, dim_t lda, short *ap) { int32_t i, j; int32_t kcell, cell, koff, moff, krows, mrows, block4, block2, mcell, chunk4count, k8, m4, m16; @@ -195,7 +194,7 @@ int pack_N16_16bit(dim_t k, dim_t m, short *a, dim_t lda, short *ap) { return 0; } -int pack_T16_16bit(dim_t k, dim_t m, short *a, dim_t lda, short *ap) { +inline int pack_T16_16bit(dim_t k, dim_t m, short *a, dim_t lda, short *ap) { int32_t i, j; int32_t kcell, cell, koff, moff, krows, mrows, block4, block2, mcell, chunk4count, k4, m8, m16; @@ -359,7 +358,7 @@ int pack_T16_16bit(dim_t k, dim_t m, short *a, dim_t lda, short *ap) { return 0; } -int pack_T8_16bit(dim_t k, dim_t n, short *b, dim_t ldb, short *bp) { +inline int pack_T8_16bit(dim_t k, dim_t n, short *b, dim_t ldb, short *bp) { int32_t i, j; int32_t kcell, cell, koff, noff, krows, k4, n8, n16; int32_t n_cap = (n + 3) & ~3; @@ -493,7 +492,7 @@ int pack_T8_16bit(dim_t k, dim_t n, short *b, dim_t ldb, short *bp) { return 0; } -int pack_N8_16bit(dim_t k, dim_t n, short *b, dim_t ldb, short *bp) { +inline int 
pack_N8_16bit(dim_t k, dim_t n, short *b, dim_t ldb, short *bp) { int32_t i, j; int32_t kcell, cell, koff, noff, krows, k8, k16, n4, n8; int32_t n_cap = (n + 3) & ~3; @@ -675,7 +674,723 @@ int pack_N8_16bit(dim_t k, dim_t n, short *b, dim_t ldb, short *bp) { return 0; } -int pack_T16_8bit(dim_t k, dim_t m, const int8_t *a, dim_t lda, int8_t *ap) { +template +inline int pack_T16_8bit_V2(dim_t K_dim, dim_t M_dim, const int8_t *A, + dim_t lda, int8_t *packA, int *row_sum) { + + vec_t mask = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; + + while (M_dim >= 16) { + + const int8_t *a = A; + + VecType V0, V1, V2, V3; + VecType D01A, D01B, D23A, D23B; + VecType D0, D1, D2, D3; + + vec_t swizA = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23}; + vec_t swizB = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + vec_t swizL = {0, 1, 16, 17, 2, 3, 18, 19, 4, 5, 20, 21, 6, 7, 22, 23}; + vec_t swizR = { + 8, 9, 24, 25, 10, 11, 26, 27, 12, 13, 28, 29, 14, 15, 30, 31}; + + __vector signed int vsum = {0}; + __vector signed int vsum1 = {0}; + __vector signed int vsum2 = {0}; + __vector signed int vsum3 = {0}; + + size_t y = K_dim; + while (y >= 4) { + V0 = *(VecType *)&a[lda * 0]; + V1 = *(VecType *)&a[lda * 1]; + V2 = *(VecType *)&a[lda * 2]; + V3 = *(VecType *)&a[lda * 3]; + + D01A = vec_perm(V0, V1, swizA); + D01B = vec_perm(V0, V1, swizB); + D23A = vec_perm(V2, V3, swizA); + D23B = vec_perm(V2, V3, swizB); + D0 = vec_perm(D01A, D23A, swizL); + D1 = vec_perm(D01A, D23A, swizR); + D2 = vec_perm(D01B, D23B, swizL); + D3 = vec_perm(D01B, D23B, swizR); + + *(VecType *)&packA[0] = D0; + *(VecType *)&packA[16] = D1; + *(VecType *)&packA[32] = D2; + *(VecType *)&packA[48] = D3; + + vsum = vec_sum4s(D0, vsum); + vsum1 = vec_sum4s(D1, vsum1); + vsum2 = vec_sum4s(D2, vsum2); + vsum3 = vec_sum4s(D3, vsum3); + + packA += 64; + y -= 4; + a += lda * 4; + } + if (y >= 1 && y <= 3) { + V0 = reinterpret_cast(vec_splats(uint8_t(0))); + V1 = 
reinterpret_cast(vec_splats(uint8_t(0))); + V2 = reinterpret_cast(vec_splats(uint8_t(0))); + V3 = reinterpret_cast(vec_splats(uint8_t(0))); + + V0 = *(VecType *)&a[lda * 0]; + if (y == 2) { V1 = *(VecType *)&a[lda * 1]; } + if (y == 3) { + V1 = *(VecType *)&a[lda * 1]; + V2 = *(VecType *)&a[lda * 2]; + } + + D01A = vec_perm(V0, V1, swizA); + D01B = vec_perm(V0, V1, swizB); + D23A = vec_perm(V2, V3, swizA); + D23B = vec_perm(V2, V3, swizB); + D0 = vec_perm(D01A, D23A, swizL); + D1 = vec_perm(D01A, D23A, swizR); + D2 = vec_perm(D01B, D23B, swizL); + D3 = vec_perm(D01B, D23B, swizR); + + *(VecType *)&packA[0] = D0; + *(VecType *)&packA[16] = D1; + *(VecType *)&packA[32] = D2; + *(VecType *)&packA[48] = D3; + + vsum = vec_sum4s(D0, vsum); + vsum1 = vec_sum4s(D1, vsum1); + vsum2 = vec_sum4s(D2, vsum2); + vsum3 = vec_sum4s(D3, vsum3); + packA += 64; + y -= 4; + a += lda * 4; + } + + row_sum[0] = vsum[0]; + row_sum[1] = vsum[1]; + row_sum[2] = vsum[2]; + row_sum[3] = vsum[3]; + row_sum[4] = vsum1[0]; + row_sum[5] = vsum1[1]; + row_sum[6] = vsum1[2]; + row_sum[7] = vsum1[3]; + + row_sum[8] = vsum2[0]; + row_sum[9] = vsum2[1]; + row_sum[10] = vsum2[2]; + row_sum[11] = vsum2[3]; + row_sum[12] = vsum3[0]; + row_sum[13] = vsum3[1]; + row_sum[14] = vsum3[2]; + row_sum[15] = vsum3[3]; + + row_sum += 16; + A += 16; + M_dim -= 16; + } + if (M_dim > 12 && M_dim < 16) { + const int8_t *a = A; + size_t y = K_dim; + size_t tail_M = M_dim - 12; + + __vector signed int vsum = {0}; + __vector signed int vsum1 = {0}; + __vector signed int vsum2 = {0}; + __vector signed int vsum3 = {0}; + + while (y >= 4) { + + int b1 = *reinterpret_cast(&a[0]); + int b2 = *reinterpret_cast(&a[lda * 1]); + int b3 = *reinterpret_cast(&a[lda * 2]); + int b4 = *reinterpret_cast(&a[lda * 3]); + __vector int vb = {b1, b2, b3, b4}; + VecType vx = vec_perm(reinterpret_cast(vb), + reinterpret_cast(vb), mask); + vsum = vec_sum4s(vx, vsum); + *reinterpret_cast(&packA[0]) + = reinterpret_cast(vx); + packA += 16; + + 
b1 = *reinterpret_cast(&a[4]); + b2 = *reinterpret_cast(&a[lda * 1 + 4]); + b3 = *reinterpret_cast(&a[lda * 2 + 4]); + b4 = *reinterpret_cast(&a[lda * 3 + 4]); + __vector int vb1 = {b1, b2, b3, b4}; + vx = vec_perm(reinterpret_cast(vb1), + reinterpret_cast(vb1), mask); + vsum1 = vec_sum4s(vx, vsum1); + *reinterpret_cast(&packA[0]) + = reinterpret_cast(vx); + packA += 16; + + b1 = *reinterpret_cast(&a[8]); + b2 = *reinterpret_cast(&a[lda * 1 + 8]); + b3 = *reinterpret_cast(&a[lda * 2 + 8]); + b4 = *reinterpret_cast(&a[lda * 3 + 8]); + __vector int vb2 = {b1, b2, b3, b4}; + vx = vec_perm(reinterpret_cast(vb2), + reinterpret_cast(vb2), mask); + vsum2 = vec_sum4s(vx, vsum2); + *reinterpret_cast(&packA[0]) + = reinterpret_cast(vx); + packA += 16; + + if (tail_M >= 1) { + VecType va1 = reinterpret_cast(vec_splats(uint8_t(0))); + va1[0] = a[12]; + va1[1] = a[lda * 1 + 12]; + va1[2] = a[lda * 2 + 12]; + va1[3] = a[lda * 3 + 12]; + + if (tail_M == 2) { + va1[4] = a[13]; + va1[5] = a[lda * 1 + 13]; + va1[6] = a[lda * 2 + 13]; + va1[7] = a[lda * 3 + 13]; + } + if (tail_M == 3) { + va1[4] = a[13]; + va1[5] = a[lda * 1 + 13]; + va1[6] = a[lda * 2 + 13]; + va1[7] = a[lda * 3 + 13]; + + va1[8] = a[14]; + va1[9] = a[lda * 1 + 14]; + va1[10] = a[lda * 2 + 14]; + va1[11] = a[lda * 3 + 14]; + } + *reinterpret_cast(&packA[0]) + = reinterpret_cast(va1); + vsum3 = vec_sum4s(va1, vsum3); + packA += 16; + } + y -= 4; + a += lda * 4; + } + if (y >= 1 && y <= 3) { + VecType va1 = reinterpret_cast(vec_splats(uint8_t(0))); + VecType va2 = reinterpret_cast(vec_splats(uint8_t(0))); + VecType va3 = reinterpret_cast(vec_splats(uint8_t(0))); + + va1[0] = a[0]; + va1[4] = a[1]; + va1[8] = a[2]; + va1[12] = a[3]; + + va2[0] = a[4]; + va2[4] = a[5]; + va2[8] = a[6]; + va2[12] = a[7]; + + va3[0] = a[8]; + va3[4] = a[9]; + va3[8] = a[10]; + va3[12] = a[11]; + + if (y == 2) { + va1[1] = a[lda]; + va1[5] = a[lda + 1]; + va1[9] = a[lda + 2]; + va1[13] = a[lda + 3]; + + va2[1] = a[lda + 4]; + va2[5] = 
a[lda + 5]; + va2[9] = a[lda + 6]; + va2[13] = a[lda + 7]; + + va3[1] = a[lda + 8]; + va3[5] = a[lda + 9]; + va3[9] = a[lda + 10]; + va3[13] = a[lda + 11]; + } + if (y == 3) { + va1[1] = a[lda]; + va1[5] = a[lda + 1]; + va1[9] = a[lda + 2]; + va1[13] = a[lda + 3]; + + va2[1] = a[lda + 4]; + va2[5] = a[lda + 5]; + va2[9] = a[lda + 6]; + va2[13] = a[lda + 7]; + + va3[1] = a[lda + 8]; + va3[5] = a[lda + 9]; + va3[9] = a[lda + 10]; + va3[13] = a[lda + 11]; + + va1[2] = a[lda * 2]; + va1[6] = a[lda * 2 + 1]; + va1[10] = a[lda * 2 + 2]; + va1[14] = a[lda * 2 + 3]; + + va2[2] = a[lda * 2 + 4]; + va2[6] = a[lda * 2 + 5]; + va2[10] = a[lda * 2 + 6]; + va2[14] = a[lda * 2 + 7]; + + va3[2] = a[lda * 2 + 8]; + va3[6] = a[lda * 2 + 9]; + va3[10] = a[lda * 2 + 10]; + va3[14] = a[lda * 2 + 11]; + } + *reinterpret_cast(&packA[0]) + = reinterpret_cast(va1); + vsum = vec_sum4s(va1, vsum); + vsum1 = vec_sum4s(va2, vsum1); + vsum2 = vec_sum4s(va3, vsum2); + *reinterpret_cast(&packA[16]) + = reinterpret_cast(va2); + *reinterpret_cast(&packA[32]) + = reinterpret_cast(va3); + packA += 48; + a += 12; + + if (tail_M > 0) { + VecType va1 = reinterpret_cast(vec_splats(uint8_t(0))); + + if (tail_M == 1) { + va1[0] = a[0]; + + if (y == 2) { va1[1] = a[lda]; } + if (y == 3) { + va1[1] = a[lda]; + va1[2] = a[lda * 2]; + } + } + if (tail_M == 2) { + va1[0] = a[0]; + va1[4] = a[1]; + if (y == 2) { + va1[1] = a[lda]; + va1[5] = a[lda + 1]; + } + if (y == 3) { + va1[1] = a[lda]; + va1[5] = a[lda + 1]; + + va1[2] = a[lda * 2]; + va1[6] = a[lda * 2 + 1]; + } + } + if (tail_M == 3) { + va1[0] = a[0]; + va1[4] = a[1]; + va1[8] = a[2]; + if (y == 2) { + va1[1] = a[lda]; + va1[5] = a[lda + 1]; + va1[9] = a[lda + 2]; + } + if (y == 3) { + va1[1] = a[lda]; + va1[5] = a[lda + 1]; + va1[9] = a[lda + 2]; + + va1[2] = a[lda * 2]; + va1[6] = a[lda * 2 + 1]; + va1[10] = a[lda * 2 + 2]; + } + } + *reinterpret_cast(&packA[0]) + = reinterpret_cast(va1); + vsum3 = vec_sum4s(va1, vsum3); + packA += 16; + } + } + 
row_sum[0] = vsum[0]; + row_sum[1] = vsum[1]; + row_sum[2] = vsum[2]; + row_sum[3] = vsum[3]; + row_sum[4] = vsum1[0]; + row_sum[5] = vsum1[1]; + row_sum[6] = vsum1[2]; + row_sum[7] = vsum1[3]; + + row_sum[8] = vsum2[0]; + row_sum[9] = vsum2[1]; + row_sum[10] = vsum2[2]; + row_sum[11] = vsum2[3]; + + if (tail_M == 1) { row_sum[12] = vsum3[0]; } + if (tail_M == 2) { + row_sum[12] = vsum3[0]; + row_sum[13] = vsum3[1]; + } + if (tail_M == 3) { + row_sum[12] = vsum3[0]; + row_sum[13] = vsum3[1]; + row_sum[14] = vsum3[2]; + } + M_dim = 0; + } + + while (M_dim >= 8) { + const int8_t *a = A; + __vector signed int vsum = {0}; + __vector signed int vsum1 = {0}; + size_t y = K_dim; + while (y >= 8) { + VecType V0, V1, V2, V3; + VecType D01A, D01B, D23A, D23B; + VecType D0, D1, D2, D3; + vec_t swizA + = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23}; + vec_t swizB = {8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, + 15, 31}; + vec_t swizL + = {0, 1, 16, 17, 2, 3, 18, 19, 4, 5, 20, 21, 6, 7, 22, 23}; + vec_t swizR = {8, 9, 24, 25, 10, 11, 26, 27, 12, 13, 28, 29, 14, 15, + 30, 31}; + + *(signed long long *)&V0[0] = *(signed long long *)&a[0]; + *(signed long long *)&V1[0] = *(signed long long *)&a[lda * 1]; + *(signed long long *)&V2[0] = *(signed long long *)&a[lda * 2]; + *(signed long long *)&V3[0] = *(signed long long *)&a[lda * 3]; + *(signed long long *)&V0[8] = *(signed long long *)&a[lda * 4]; + *(signed long long *)&V1[8] = *(signed long long *)&a[lda * 5]; + *(signed long long *)&V2[8] = *(signed long long *)&a[lda * 6]; + *(signed long long *)&V3[8] = *(signed long long *)&a[lda * 7]; + + D01A = vec_perm(V0, V1, swizA); + D01B = vec_perm(V0, V1, swizB); + D23A = vec_perm(V2, V3, swizA); + D23B = vec_perm(V2, V3, swizB); + D0 = vec_perm(D01A, D23A, swizL); + D1 = vec_perm(D01A, D23A, swizR); + D2 = vec_perm(D01B, D23B, swizL); + D3 = vec_perm(D01B, D23B, swizR); + + *(VecType *)&packA[0] = D0; + *(VecType *)&packA[16] = D1; + *(VecType *)&packA[32] = 
D2; + *(VecType *)&packA[48] = D3; + + vsum = vec_sum4s(D0, vsum); + vsum = vec_sum4s(D2, vsum); + vsum1 = vec_sum4s(D1, vsum1); + vsum1 = vec_sum4s(D3, vsum1); + packA += 64; + y -= 8; + a += lda * 8; + } + if (y >= 4) { + int b1 = *reinterpret_cast(&a[0]); + int b2 = *reinterpret_cast(&a[lda]); + int b3 = *reinterpret_cast(&a[lda * 2]); + int b4 = *reinterpret_cast(&a[lda * 3]); + __vector int vb = {b1, b2, b3, b4}; + VecType vx = vec_perm(reinterpret_cast(vb), + reinterpret_cast(vb), mask); + *reinterpret_cast(&packA[0]) + = reinterpret_cast(vx); + packA += 16; + + vsum = vec_sum4s(vx, vsum); + + b1 = *reinterpret_cast(&a[4]); + b2 = *reinterpret_cast(&a[lda + 4]); + b3 = *reinterpret_cast(&a[lda * 2 + 4]); + b4 = *reinterpret_cast(&a[lda * 3 + 4]); + + __vector int vb1 = {b1, b2, b3, b4}; + VecType vx1 = vec_perm(reinterpret_cast(vb1), + reinterpret_cast(vb1), mask); + *reinterpret_cast(&packA[0]) + = reinterpret_cast(vx1); + packA += 16; + vsum1 = vec_sum4s(vx1, vsum1); + y -= 4; + a += lda * 4; + } + + if (y >= 1 && y <= 3) { + VecType va1 = reinterpret_cast(vec_splats(uint8_t(0))); + VecType va2 = reinterpret_cast(vec_splats(uint8_t(0))); + va1[0] = a[0]; + va1[4] = a[1]; + va1[8] = a[2]; + va1[12] = a[3]; + + va2[0] = a[4]; + va2[4] = a[5]; + va2[8] = a[6]; + va2[12] = a[7]; + if (y == 2) { + va1[1] = a[lda]; + va1[5] = a[lda + 1]; + va1[9] = a[lda + 2]; + va1[13] = a[lda + 3]; + + va2[1] = a[lda + 4]; + va2[5] = a[lda + 5]; + va2[9] = a[lda + 6]; + va2[13] = a[lda + 7]; + } + if (y == 3) { + va1[1] = a[lda]; + va1[5] = a[lda + 1]; + va1[9] = a[lda + 2]; + va1[13] = a[lda + 3]; + + va2[1] = a[lda + 4]; + va2[5] = a[lda + 5]; + va2[9] = a[lda + 6]; + va2[13] = a[lda + 7]; + + va1[2] = a[lda * 2]; + va1[6] = a[lda * 2 + 1]; + va1[10] = a[lda * 2 + 2]; + va1[14] = a[lda * 2 + 3]; + + va2[2] = a[lda * 2 + 4]; + va2[6] = a[lda * 2 + 5]; + va2[10] = a[lda * 2 + 6]; + va2[14] = a[lda * 2 + 7]; + } + *reinterpret_cast(&packA[0]) + = reinterpret_cast(va1); + 
*reinterpret_cast(&packA[16]) + = reinterpret_cast(va2); + packA += 32; + vsum = vec_sum4s(va1, vsum); + vsum1 = vec_sum4s(va2, vsum1); + } + row_sum[0] = vsum[0]; + row_sum[1] = vsum[1]; + row_sum[2] = vsum[2]; + row_sum[3] = vsum[3]; + + row_sum[4] = vsum1[0]; + row_sum[5] = vsum1[1]; + row_sum[6] = vsum1[2]; + row_sum[7] = vsum1[3]; + row_sum += 8; + A += 8; + M_dim -= 8; + } + + if (M_dim < 8 && M_dim >= 4) { + const int8_t *a = A; + __vector signed int vsum = {0}; + __vector signed int vsum1 = {0}; + size_t y = K_dim; + size_t tail_M = M_dim - 4; + while (y >= 4) { + int b1 = *reinterpret_cast(&a[0]); + int b2 = *reinterpret_cast(&a[lda]); + int b3 = *reinterpret_cast(&a[lda * 2]); + int b4 = *reinterpret_cast(&a[lda * 3]); + __vector int vb = {b1, b2, b3, b4}; + VecType vx = vec_perm(reinterpret_cast(vb), + reinterpret_cast(vb), mask); + *reinterpret_cast(&packA[0]) + = reinterpret_cast(vx); + packA += 16; + vsum = vec_sum4s(vx, vsum); + + if (tail_M >= 1) { + VecType va1 = reinterpret_cast(vec_splats(uint8_t(0))); + va1[0] = a[4]; + va1[1] = a[lda + 4]; + va1[2] = a[lda * 2 + 4]; + va1[3] = a[lda * 3 + 4]; + if (tail_M == 2) { + va1[4] = a[5]; + va1[5] = a[lda + 5]; + va1[6] = a[lda * 2 + 5]; + va1[7] = a[lda * 3 + 5]; + } + if (tail_M == 3) { + va1[4] = a[5]; + va1[5] = a[lda + 5]; + va1[6] = a[lda * 2 + 5]; + va1[7] = a[lda * 3 + 5]; + va1[8] = a[6]; + va1[9] = a[lda + 6]; + va1[10] = a[lda * 2 + 6]; + va1[11] = a[lda * 3 + 6]; + } + *reinterpret_cast(&packA[0]) + = reinterpret_cast(va1); + packA += 16; + vsum1 = vec_sum4s(va1, vsum1); + } + a += lda * 4; + y -= 4; + } + + if (y >= 1 && y <= 3) { + + int a1 = 0, a2 = 0, a3 = 0, a4 = 0; + a1 = *reinterpret_cast(&a[0]); + if (y == 2) { a2 = *reinterpret_cast(&a[lda]); } + if (y == 3) { + a2 = *reinterpret_cast(&a[lda]); + a3 = *reinterpret_cast(&a[lda * 2]); + } + __vector int vb = {a1, a2, a3, a4}; + VecType vx = vec_perm(reinterpret_cast(vb), + reinterpret_cast(vb), mask); + + *reinterpret_cast(&packA[0]) 
+ = reinterpret_cast(vx); + packA += 16; + vsum = vec_sum4s(vx, vsum); + + if (tail_M >= 1) { + VecType va1 = reinterpret_cast(vec_splats(uint8_t(0))); + if (y == 1) { + va1[0] = a[4]; + if (tail_M == 2) { va1[4] = a[5]; } + if (tail_M == 3) { + va1[4] = a[5]; + va1[8] = a[6]; + } + } + if (y == 2) { + va1[0] = a[4]; + va1[1] = a[lda + 4]; + if (tail_M == 2) { + va1[4] = a[5]; + va1[5] = a[lda + 5]; + } + if (tail_M == 3) { + va1[4] = a[5]; + va1[5] = a[lda + 5]; + va1[8] = a[6]; + va1[9] = a[lda + 6]; + } + } + if (y == 3) { + va1[0] = a[4]; + va1[1] = a[lda + 4]; + va1[2] = a[lda * 2 + 4]; + if (tail_M == 2) { + va1[4] = a[5]; + va1[5] = a[lda + 5]; + va1[6] = a[lda * 2 + 5]; + } + if (tail_M == 3) { + va1[4] = a[5]; + va1[5] = a[lda + 5]; + va1[6] = a[lda * 2 + 5]; + va1[8] = a[6]; + va1[9] = a[lda + 6]; + va1[10] = a[lda * 2 + 6]; + } + } + *reinterpret_cast(&packA[0]) + = reinterpret_cast(va1); + packA += 4; + vsum1 = vec_sum4s(va1, vsum1); + } + } + row_sum[0] = vsum[0]; + row_sum[1] = vsum[1]; + row_sum[2] = vsum[2]; + row_sum[3] = vsum[3]; + row_sum += 4; + if (tail_M > 0) { + row_sum[0] = vsum1[0]; + row_sum[1] = vsum1[1]; + row_sum[2] = vsum1[2]; + row_sum[3] = vsum1[3]; + row_sum += 4; + } + } + if (M_dim >= 1 && M_dim <= 3) { + const int8_t *a = A; + __vector signed int vsum = {0}; + size_t y = K_dim; + while (y >= 4) { + VecType va1 = reinterpret_cast(vec_splats(uint8_t(0))); + va1[0] = a[0]; + va1[1] = a[lda]; + va1[2] = a[lda * 2]; + va1[3] = a[lda * 3]; + if (M_dim == 2) { + va1[4] = a[1]; + va1[5] = a[lda + 1]; + va1[6] = a[lda * 2 + 1]; + va1[7] = a[lda * 3 + 1]; + } + if (M_dim == 3) { + va1[4] = a[1]; + va1[5] = a[lda + 1]; + va1[6] = a[lda * 2 + 1]; + va1[7] = a[lda * 3 + 1]; + va1[8] = a[2]; + va1[9] = a[lda + 2]; + va1[10] = a[lda * 2 + 2]; + va1[11] = a[lda * 3 + 2]; + } + *reinterpret_cast(&packA[0]) + = reinterpret_cast(va1); + vsum = vec_sum4s(va1, vsum); + packA += 16; + a += lda * 4; + y -= 4; + } + if (y >= 1 && y <= 3) { + VecType va1 
= reinterpret_cast(vec_splats(uint8_t(0))); + if (y == 1) { + va1[0] = a[0]; + if (M_dim == 2) { va1[4] = a[1]; } + if (M_dim == 3) { + va1[4] = a[1]; + va1[8] = a[2]; + } + } + if (y == 2) { + va1[0] = a[0]; + va1[1] = a[lda]; + if (M_dim == 2) { + va1[4] = a[1]; + va1[5] = a[lda + 1]; + } + if (M_dim == 3) { + va1[4] = a[1]; + va1[5] = a[lda + 1]; + va1[8] = a[2]; + va1[9] = a[lda + 2]; + } + } + if (y == 3) { + va1[0] = a[0]; + va1[1] = a[lda]; + va1[2] = a[lda * 2]; + if (M_dim == 2) { + va1[4] = a[1]; + va1[5] = a[lda + 1]; + va1[6] = a[lda * 2 + 1]; + } + if (M_dim == 3) { + va1[4] = a[1]; + va1[5] = a[lda + 1]; + va1[6] = a[lda * 2 + 1]; + va1[8] = a[2]; + va1[9] = a[lda + 2]; + va1[10] = a[lda * 2 + 2]; + } + } + *reinterpret_cast(&packA[0]) + = reinterpret_cast(va1); + vsum = vec_sum4s(va1, vsum); + packA += 16; + } + row_sum[0] = vsum[0]; + row_sum[1] = vsum[1]; + row_sum[2] = vsum[2]; + row_sum[3] = vsum[3]; + row_sum += 4; + } + return 0; +} + +inline int pack_T16_8bit( + dim_t k, dim_t m, const int8_t *a, dim_t lda, int8_t *ap) { int32_t i, j; int32_t m_cap = (m + 3) & ~3; int32_t k_cap = (k + 3) & ~3; @@ -810,11 +1525,1987 @@ int pack_T16_8bit(dim_t k, dim_t m, const int8_t *a, dim_t lda, int8_t *ap) { ap[16 * cell + 4 * moff + koff] = 0; } } - + + return 0; +} + +template +int pack_N8_8bit_V2_lxvp( + dim_t K_dim, dim_t N_dim, const uint8_t *B, dim_t ldb, uint8_t *Bp) { + + while (N_dim >= 8) { + uint8_t *b = const_cast(B); + size_t y = K_dim; + while (y >= 32) { + __vector_pair row1, row2, row3, row4, row5, row6, row7, row8; + VecType r1[2], r2[2], r3[2], r4[2], r5[2], r6[2], r7[2], r8[2]; + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + VecType swizA = { + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + VecType swizB = { + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + + row1 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[0])); + row2 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair 
*>(&b[ldb])); + row3 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 2])); + row4 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 3])); + row5 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 4])); + row6 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 5])); + row7 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 6])); + row8 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 7])); + + __builtin_vsx_disassemble_pair(r1, &row1); + __builtin_vsx_disassemble_pair(r2, &row2); + __builtin_vsx_disassemble_pair(r3, &row3); + __builtin_vsx_disassemble_pair(r4, &row4); + __builtin_vsx_disassemble_pair(r5, &row5); + __builtin_vsx_disassemble_pair(r6, &row6); + __builtin_vsx_disassemble_pair(r7, &row7); + __builtin_vsx_disassemble_pair(r8, &row8); + + // First 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[0] = D0; + *(VecType *)&Bp[32] = D1; + *(VecType *)&Bp[64] = D2; + *(VecType *)&Bp[96] = D3; + + // Next (ldb * 4) 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector 
int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[16] = D0; + *(VecType *)&Bp[48] = D1; + *(VecType *)&Bp[80] = D2; + *(VecType *)&Bp[112] = D3; + + // First 4 Rows and Second 16 columns + + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[128] = D0; + *(VecType *)&Bp[160] = D1; + *(VecType *)&Bp[192] = D2; + *(VecType *)&Bp[224] = D3; + + // Next (ldb * 4) 4 Rows and Second 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[144] = D0; + *(VecType *)&Bp[176] = D1; + *(VecType *)&Bp[208] = D2; + *(VecType *)&Bp[240] = D3; + + y -= 32; + b += 32; + 
Bp += 8 * 32; + } + while (y >= 16) { + // First 4th row and 16 Columns + VecType b1 = *reinterpret_cast(&b[0]); + VecType b2 = *reinterpret_cast(&b[ldb]); + VecType b3 = *reinterpret_cast(&b[ldb * 2]); + VecType b4 = *reinterpret_cast(&b[ldb * 3]); + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + VecType vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row3 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row5 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[0]) = vx_row1; + *reinterpret_cast(&Bp[32]) = vx_row3; + *reinterpret_cast(&Bp[64]) = vx_row5; + *reinterpret_cast(&Bp[96]) = vx_row7; + + // Second 4th Row and 16 Columns + b1 = *reinterpret_cast(&b[ldb * 4]); + b2 = *reinterpret_cast(&b[ldb * 5]); + b3 = *reinterpret_cast(&b[ldb * 6]); + b4 = *reinterpret_cast(&b[ldb * 7]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row2 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row4 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row6 = 
vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row8 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[16]) = vx_row2; + *reinterpret_cast(&Bp[48]) = vx_row4; + *reinterpret_cast(&Bp[80]) = vx_row6; + *reinterpret_cast(&Bp[112]) = vx_row8; + + b += 16; + Bp += 128; + y -= 16; + } + while (y >= 8) { + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + VecType swizA = { + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + VecType swizB = { + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + *(signed long long *)&V0[0] = *(signed long long *)&b[ldb * 0]; + *(signed long long *)&V0[8] = *(signed long long *)&b[ldb * 1]; + *(signed long long *)&V1[0] = *(signed long long *)&b[ldb * 2]; + *(signed long long *)&V1[8] = *(signed long long *)&b[ldb * 3]; + *(signed long long *)&V2[0] = *(signed long long *)&b[ldb * 4]; + *(signed long long *)&V2[8] = *(signed long long *)&b[ldb * 5]; + *(signed long long *)&V3[0] = *(signed long long *)&b[ldb * 6]; + *(signed long long *)&V3[8] = *(signed long long *)&b[ldb * 7]; + + D0 = vec_perm(V0, V1, swizA); + D1 = vec_perm(V2, V3, swizA); + D2 = vec_perm(V0, V1, swizB); + D3 = vec_perm(V2, V3, swizB); + + *(VecType *)&Bp[0] = D0; + *(VecType *)&Bp[16] = D1; + *(VecType *)&Bp[32] = D2; + *(VecType *)&Bp[48] = D3; + + Bp += 64; + b += 8; + y -= 8; + } + while (y >= 4) { + int a1 = *reinterpret_cast(&b[0]); + int a2 = *reinterpret_cast(&b[ldb * 1]); + int a3 = *reinterpret_cast(&b[ldb * 2]); + int a4 = *reinterpret_cast(&b[ldb * 3]); + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *reinterpret_cast(&Bp[0]) = vec_row1; + + a1 = *reinterpret_cast(&b[ldb * 4]); + a2 = *reinterpret_cast(&b[ldb * 5]); + a3 = *reinterpret_cast(&b[ldb * 6]); + a4 = *reinterpret_cast(&b[ldb * 7]); + __vector int vec_a1 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a1); + *reinterpret_cast(&Bp[16]) = vec_row1; + Bp += 32; + y -= 4; + b += 4; + } + if (y >= 1 && y <= 3) 
{ + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail2 + = reinterpret_cast(vec_splats(uint8_t(0))); + + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + vec_tail1[12] = b[ldb * 3]; + + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + vec_tail2[8] = b[ldb * 6]; + vec_tail2[12] = b[ldb * 7]; + + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[13] = b[ldb * 7 + 1]; + + vec_tail1[2] = b[2]; + vec_tail1[6] = b[ldb + 2]; + vec_tail1[10] = b[ldb * 2 + 2]; + vec_tail1[14] = b[ldb * 3 + 2]; + + vec_tail2[2] = b[ldb * 4 + 2]; + vec_tail2[6] = b[ldb * 5 + 2]; + vec_tail2[10] = b[ldb * 6 + 2]; + vec_tail2[14] = b[ldb * 7 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[13] = b[ldb * 7 + 1]; + } + *reinterpret_cast(&Bp[0]) = vec_tail1; + *reinterpret_cast(&Bp[16]) = vec_tail2; + Bp += 32; + } + N_dim -= 8; + B += 8 * ldb; + } + if (N_dim >= 4 && N_dim < 8) { + uint8_t *b = const_cast(B); + size_t y = K_dim; + int tail_N = N_dim - 4; + while (y >= 32) { + + __vector_pair row1, row2, row3, row4, row5, row6, row7, row8; + VecType r1[2] = {0}, r2[2] = {0}, r3[2] = {0}, r4[2] = {0}, + r5[2] = {0}, r6[2] = {0}, r7[2] = {0}, r8[2] = {0}; + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + VecType swizA = { + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + VecType swizB = { + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + row1 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[0])); + row2 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb])); + row3 = 
__builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 2])); + row4 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 3])); + + __builtin_vsx_disassemble_pair(r1, &row1); + __builtin_vsx_disassemble_pair(r2, &row2); + __builtin_vsx_disassemble_pair(r3, &row3); + __builtin_vsx_disassemble_pair(r4, &row4); + + // First 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[0] = D0; + if (tail_N == 0) { + *(VecType *)&Bp[16] = D1; + *(VecType *)&Bp[32] = D2; + *(VecType *)&Bp[48] = D3; + } + if (tail_N >= 1) { + *(VecType *)&Bp[32] = D1; + *(VecType *)&Bp[64] = D2; + *(VecType *)&Bp[96] = D3; + } + + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + if (tail_N == 0) { + *(VecType *)&Bp[64] = D0; + *(VecType *)&Bp[80] = D1; + *(VecType *)&Bp[96] = D2; + *(VecType *)&Bp[112] = D3; + } + if 
(tail_N >= 1) { + *(VecType *)&Bp[128] = D0; + *(VecType *)&Bp[160] = D1; + *(VecType *)&Bp[192] = D2; + *(VecType *)&Bp[224] = D3; + } + if (tail_N >= 1) { + row5 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 4])); + __builtin_vsx_disassemble_pair(r5, &row5); + if (tail_N == 3) { + row6 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 5])); + row7 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 6])); + __builtin_vsx_disassemble_pair(r6, &row6); + __builtin_vsx_disassemble_pair(r7, &row7); + } + if (tail_N == 2) { + row6 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 5])); + __builtin_vsx_disassemble_pair(r6, &row6); + } + + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[16] = D0; + *(VecType *)&Bp[48] = D1; + *(VecType *)&Bp[80] = D2; + *(VecType *)&Bp[112] = D3; + + //Next 16 Columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = 
vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[144] = D0; + *(VecType *)&Bp[176] = D1; + *(VecType *)&Bp[208] = D2; + *(VecType *)&Bp[240] = D3; + } + if (tail_N == 0) { + Bp += 128; + } else { + Bp += 256; + } + b += 32; + y -= 32; + } + + while (y >= 16) { + VecType b1 = *reinterpret_cast(&b[0]); + VecType b2 = *reinterpret_cast(&b[ldb]); + VecType b3 = *reinterpret_cast(&b[ldb * 2]); + VecType b4 = *reinterpret_cast(&b[ldb * 3]); + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + VecType vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row3 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row5 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[0]) = vx_row1; + if (tail_N == 0) { + *reinterpret_cast(&Bp[16]) = vx_row3; + *reinterpret_cast(&Bp[32]) = vx_row5; + *reinterpret_cast(&Bp[48]) = vx_row7; + } + if (tail_N >= 1) { + *reinterpret_cast(&Bp[32]) = vx_row3; + *reinterpret_cast(&Bp[64]) = vx_row5; + *reinterpret_cast(&Bp[96]) = vx_row7; + } + + if (tail_N >= 1) { + VecType b5 = {0}, b6 = {0}, b7 = {0}, b8 = {0}; + b5 = *reinterpret_cast(&b[ldb * 4]); + + if (tail_N == 3) { + b6 = *reinterpret_cast(&b[ldb * 5]); + b7 = *reinterpret_cast(&b[ldb * 6]); + } + if (tail_N == 2) { + b6 = *reinterpret_cast(&b[ldb * 5]); + } + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b5), + reinterpret_cast<__vector 
int>(b6))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b7), + reinterpret_cast<__vector int>(b8))); + VecType vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b5), + reinterpret_cast<__vector int>(b6))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b7), + reinterpret_cast<__vector int>(b8))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row3 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row5 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[16]) = vx_row1; + *reinterpret_cast(&Bp[48]) = vx_row3; + *reinterpret_cast(&Bp[80]) = vx_row5; + *reinterpret_cast(&Bp[112]) = vx_row7; + } + b += 16; + if (tail_N >= 1) { + Bp += 16 * 8; + } else { + Bp += 16 * 4; + } + y -= 16; + } + + while (y >= 8) { + VecType V0 = {0}, V1 = {0}, V2 = {0}, V3 = {0}; + VecType D0, D1, D2, D3; + + VecType swizA = { + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + VecType swizB = { + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + + *(signed long long *)&V0[0] = *(signed long long *)&b[ldb * 0]; + *(signed long long *)&V0[8] = *(signed long long *)&b[ldb * 1]; + *(signed long long *)&V1[0] = *(signed long long *)&b[ldb * 2]; + *(signed long long *)&V1[8] = *(signed long long *)&b[ldb * 3]; + + D0 = vec_perm(V0, V1, swizA); + D2 = vec_perm(V0, V1, swizB); + + *reinterpret_cast(&Bp[0]) = D0; + if (tail_N == 0) { + *reinterpret_cast(&Bp[16]) = D2; + } else { + *reinterpret_cast(&Bp[32]) = D2; + } + + if (tail_N >= 1) { + *(signed long long *)&V2[0] = *(signed long long *)&b[ldb * 4]; + if (tail_N == 3) { + *(signed long long *)&V2[8] + = *(signed long long *)&b[ldb * 5]; + *(signed long long *)&V3[0] + = *(signed long long *)&b[ldb * 6]; + } + if (tail_N == 2) { + *(signed long long *)&V2[8] + = *(signed long long *)&b[ldb * 5]; + } + D1 = 
vec_perm(V2, V3, swizA); + D3 = vec_perm(V2, V3, swizB); + + *(VecType *)&Bp[16] = D1; + *(VecType *)&Bp[48] = D3; + } + b += 8; + + if (tail_N >= 1) { + Bp += 8 * 8; + } else { + Bp += 4 * 8; + } + y -= 8; + } + + while (y >= 4) { + int b1 = *reinterpret_cast(&b[0]); + int b2 = *reinterpret_cast(&b[ldb * 1]); + int b3 = *reinterpret_cast(&b[ldb * 2]); + int b4 = *reinterpret_cast(&b[ldb * 3]); + __vector int vec_a = {b1, b2, b3, b4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *(VecType *)&Bp[0] = vec_row1; + Bp += 16; + + if (tail_N >= 1) { + b2 = 0; + b3 = 0; + b4 = 0; + b1 = *reinterpret_cast(&b[ldb * 4]); + + if (tail_N == 3) { + b2 = *reinterpret_cast(&b[ldb * 5]); + b3 = *reinterpret_cast(&b[ldb * 6]); + } + if (tail_N == 2) { + b2 = *reinterpret_cast(&b[ldb * 5]); + } + __vector int vec_a1 = {b1, b2, b3, b4}; + VecType vec_row2 = reinterpret_cast(vec_a1); + *(VecType *)&Bp[0] = vec_row2; + Bp += 16; + } + y -= 4; + b += 4; + } + + if (y >= 1 && y <= 3) { + VecType vec_tail1 = {0}; + VecType vec_tail2 = {0}; + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + vec_tail1[12] = b[ldb * 3]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail1[2] = b[2]; + vec_tail1[6] = b[ldb + 2]; + vec_tail1[10] = b[ldb * 2 + 2]; + vec_tail1[14] = b[ldb * 3 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + } + *reinterpret_cast(&Bp[0]) = vec_tail1; + Bp += 16; + + if (tail_N >= 1) { + if (tail_N == 1) { + vec_tail2[0] = b[ldb * 4]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + } + if (y == 2) { vec_tail2[1] = b[ldb * 4 + 1]; } + } + if (tail_N == 2) { + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + vec_tail2[5] = b[ldb * 5 + 
1]; + vec_tail2[6] = b[ldb * 5 + 2]; + } + if (y == 2) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + } + } + if (tail_N == 3) { + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + vec_tail2[8] = b[ldb * 6]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[6] = b[ldb * 5 + 2]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[10] = b[ldb * 6 + 2]; + } + if (y == 2) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + } + } + *reinterpret_cast(&Bp[0]) = vec_tail2; + Bp += 16; + } + } + } + if (N_dim <= 3 && N_dim >= 1) { + + const uint8_t *b = B; + size_t y = K_dim; + while (y >= 4) { + int a1 = *reinterpret_cast(&b[0]); + int a2 = 0, a3 = 0, a4 = 0; + if (N_dim == 3) { + a2 = *reinterpret_cast(&b[ldb * 1]); + a3 = *reinterpret_cast(&b[ldb * 2]); + } + if (N_dim == 2) { + a2 = *reinterpret_cast(&b[ldb * 1]); + } + + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *(VecType *)&Bp[0] = vec_row1; + + Bp += 16; + b += 4; + y -= 4; + } + + if (y >= 1 && y <= 3) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + + int tail_N = N_dim; + + if (tail_N >= 1) { + if (tail_N == 1) { + vec_tail1[0] = b[0]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + } + if (y == 2) { vec_tail1[1] = b[1]; } + } + if (tail_N == 2) { + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + + vec_tail1[5] = b[ldb * 1 + 1]; + vec_tail1[6] = b[ldb * 1 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb * 1 + 1]; + } + } + if (tail_N == 3) { + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + + vec_tail1[5] = b[ldb + 1]; + vec_tail1[6] = b[ldb + 2]; + + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[10] = b[ldb * 2 + 
2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + } + } + + *reinterpret_cast(&Bp[0]) = vec_tail1; + Bp += 16; + } + } + } + return 0; +} + +template +int pack_N8_8bit_V2_lxvp_signed(dim_t K_dim, dim_t N_dim, const BufType *B, + dim_t ldb, uint8_t *Bp, bool is_signed) { + const uint8_t BitFlipValue = (is_signed ? 0x80 : 0); + VecType vmask = reinterpret_cast(vec_splats(BitFlipValue)); + const int8_t Flip = (is_signed ? -128 : 0); + + typedef __vector unsigned char vec_t; + + while (N_dim >= 8) { + BufType *b = const_cast(B); + size_t y = K_dim; + while (y >= 32) { + __vector_pair row1, row2, row3, row4, row5, row6, row7, row8; + VecType r1[2], r2[2], r3[2], r4[2], r5[2], r6[2], r7[2], r8[2]; + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + + row1 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[0])); + row2 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb])); + row3 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 2])); + row4 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 3])); + row5 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 4])); + row6 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 5])); + row7 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 6])); + row8 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 7])); + + __builtin_vsx_disassemble_pair(r1, &row1); + __builtin_vsx_disassemble_pair(r2, &row2); + __builtin_vsx_disassemble_pair(r3, &row3); + __builtin_vsx_disassemble_pair(r4, &row4); + __builtin_vsx_disassemble_pair(r5, &row5); + __builtin_vsx_disassemble_pair(r6, &row6); + __builtin_vsx_disassemble_pair(r7, &row7); + __builtin_vsx_disassemble_pair(r8, &row8); + + // First 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V1 
= reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(vec_t *)&Bp[0] = reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[32] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[64] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[96] = reinterpret_cast(vec_add(D3, vmask)); + + // Next (ldb * 4) 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(vec_t *)&Bp[16] = reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[48] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[80] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[112] = reinterpret_cast(vec_add(D3, vmask)); + + // First 4 Rows and Second 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[1]), + 
reinterpret_cast<__vector int>(r2[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(vec_t *)&Bp[128] = reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[160] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[192] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[224] = reinterpret_cast(vec_add(D3, vmask)); + + // Next (ldb * 4) 4 Rows and Second 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(vec_t *)&Bp[144] = reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[176] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[208] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[240] = reinterpret_cast(vec_add(D3, vmask)); + + y -= 32; + b += 32; + Bp += 8 * 32; + } + while (y >= 16) { + // First 4th row and 16 Columns + VecType b1 = *reinterpret_cast(&b[0]); + VecType b2 = *reinterpret_cast(&b[ldb]); + VecType b3 = *reinterpret_cast(&b[ldb * 2]); + VecType b4 = *reinterpret_cast(&b[ldb * 3]); + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + VecType 
vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row3 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row5 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[0]) + = reinterpret_cast(vec_add(vx_row1, vmask)); + *reinterpret_cast(&Bp[32]) + = reinterpret_cast(vec_add(vx_row3, vmask)); + *reinterpret_cast(&Bp[64]) + = reinterpret_cast(vec_add(vx_row5, vmask)); + *reinterpret_cast(&Bp[96]) + = reinterpret_cast(vec_add(vx_row7, vmask)); + + // Second 4th Row and 16 Columns + b1 = *reinterpret_cast(&b[ldb * 4]); + b2 = *reinterpret_cast(&b[ldb * 5]); + b3 = *reinterpret_cast(&b[ldb * 6]); + b4 = *reinterpret_cast(&b[ldb * 7]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row2 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row4 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row6 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row8 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[16]) + = reinterpret_cast(vec_add(vx_row2, vmask)); + *reinterpret_cast(&Bp[48]) + = reinterpret_cast(vec_add(vx_row4, vmask)); + *reinterpret_cast(&Bp[80]) + = reinterpret_cast(vec_add(vx_row6, vmask)); + *reinterpret_cast(&Bp[112]) + = 
reinterpret_cast(vec_add(vx_row8, vmask)); + + b += 16; + Bp += 128; + y -= 16; + } + while (y >= 8) { + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + vec_t swizA = { + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vec_t swizB = { + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + *(signed long long *)&V0[0] = *(signed long long *)&b[ldb * 0]; + *(signed long long *)&V0[8] = *(signed long long *)&b[ldb * 1]; + *(signed long long *)&V1[0] = *(signed long long *)&b[ldb * 2]; + *(signed long long *)&V1[8] = *(signed long long *)&b[ldb * 3]; + *(signed long long *)&V2[0] = *(signed long long *)&b[ldb * 4]; + *(signed long long *)&V2[8] = *(signed long long *)&b[ldb * 5]; + *(signed long long *)&V3[0] = *(signed long long *)&b[ldb * 6]; + *(signed long long *)&V3[8] = *(signed long long *)&b[ldb * 7]; + + D0 = vec_perm(V0, V1, swizA); + D1 = vec_perm(V2, V3, swizA); + D2 = vec_perm(V0, V1, swizB); + D3 = vec_perm(V2, V3, swizB); + + *(vec_t *)&Bp[0] = reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[16] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[32] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[48] = reinterpret_cast(vec_add(D3, vmask)); + + Bp += 64; + b += 8; + y -= 8; + } + while (y >= 4) { + int a1 = *reinterpret_cast(&b[0]); + int a2 = *reinterpret_cast(&b[ldb * 1]); + int a3 = *reinterpret_cast(&b[ldb * 2]); + int a4 = *reinterpret_cast(&b[ldb * 3]); + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *reinterpret_cast(&Bp[0]) + = reinterpret_cast(vec_add(vec_row1, vmask)); + + a1 = *reinterpret_cast(&b[ldb * 4]); + a2 = *reinterpret_cast(&b[ldb * 5]); + a3 = *reinterpret_cast(&b[ldb * 6]); + a4 = *reinterpret_cast(&b[ldb * 7]); + __vector int vec_a1 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a1); + *reinterpret_cast(&Bp[16]) + = reinterpret_cast(vec_add(vec_row1, vmask)); + Bp += 32; + y -= 4; + b += 4; + } + if (y >= 1 && y <= 3) { + VecType vec_tail1 = 
reinterpret_cast(vec_splats(Flip)); + VecType vec_tail2 = reinterpret_cast(vec_splats(Flip)); + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + vec_tail1[12] = b[ldb * 3]; + + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + vec_tail2[8] = b[ldb * 6]; + vec_tail2[12] = b[ldb * 7]; + + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[13] = b[ldb * 7 + 1]; + + vec_tail1[2] = b[2]; + vec_tail1[6] = b[ldb + 2]; + vec_tail1[10] = b[ldb * 2 + 2]; + vec_tail1[14] = b[ldb * 3 + 2]; + + vec_tail2[2] = b[ldb * 4 + 2]; + vec_tail2[6] = b[ldb * 5 + 2]; + vec_tail2[10] = b[ldb * 6 + 2]; + vec_tail2[14] = b[ldb * 7 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[13] = b[ldb * 7 + 1]; + } + *reinterpret_cast(&Bp[0]) + = reinterpret_cast(vec_add(vec_tail1, vmask)); + *reinterpret_cast(&Bp[16]) + = reinterpret_cast(vec_add(vec_tail2, vmask)); + Bp += 32; + } + N_dim -= 8; + B += 8 * ldb; + } + if (N_dim >= 4 && N_dim < 8) { + BufType *b = const_cast(B); + size_t y = K_dim; + int tail_N = N_dim - 4; + while (y >= 32) { + + __vector_pair row1, row2, row3, row4, row5, row6, row7; + VecType r1[2] = {0}, r2[2] = {0}, r3[2] = {0}, r4[2] = {0}, + r5[2] = {0}, r6[2] = {0}, r7[2] = {0}, r8[2] = {0}; + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + row1 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[0])); + row2 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb])); + row3 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 2])); + row4 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair 
*>(&b[ldb * 3])); + + __builtin_vsx_disassemble_pair(r1, &row1); + __builtin_vsx_disassemble_pair(r2, &row2); + __builtin_vsx_disassemble_pair(r3, &row3); + __builtin_vsx_disassemble_pair(r4, &row4); + + // First 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(vec_t *)&Bp[0] = reinterpret_cast(vec_add(D0, vmask)); + if (tail_N == 0) { + *(vec_t *)&Bp[16] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[32] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[48] = reinterpret_cast(vec_add(D3, vmask)); + } + if (tail_N >= 1) { + *(vec_t *)&Bp[32] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[64] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[96] = reinterpret_cast(vec_add(D3, vmask)); + } + + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + if (tail_N == 0) { + *(vec_t *)&Bp[64] = 
reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[80] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[96] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[112] + = reinterpret_cast(vec_add(D3, vmask)); + } + if (tail_N >= 1) { + *(vec_t *)&Bp[128] + = reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[160] + = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[192] + = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[224] + = reinterpret_cast(vec_add(D3, vmask)); + } + if (tail_N >= 1) { + row5 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 4])); + __builtin_vsx_disassemble_pair(r5, &row5); + if (tail_N == 3) { + row6 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 5])); + row7 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 6])); + __builtin_vsx_disassemble_pair(r6, &row6); + __builtin_vsx_disassemble_pair(r7, &row7); + } + if (tail_N == 2) { + row6 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 5])); + __builtin_vsx_disassemble_pair(r6, &row6); + } + + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(vec_t *)&Bp[16] = reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[48] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[80] = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[112] + = reinterpret_cast(vec_add(D3, vmask)); + + //Next 16 Columns + V0 = reinterpret_cast( + 
vec_mergee(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(vec_t *)&Bp[144] + = reinterpret_cast(vec_add(D0, vmask)); + *(vec_t *)&Bp[176] + = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[208] + = reinterpret_cast(vec_add(D2, vmask)); + *(vec_t *)&Bp[240] + = reinterpret_cast(vec_add(D3, vmask)); + } + if (tail_N == 0) { + Bp += 128; + } else { + Bp += 256; + } + b += 32; + y -= 32; + } + + while (y >= 16) { + VecType b1 = *reinterpret_cast(&b[0]); + VecType b2 = *reinterpret_cast(&b[ldb]); + VecType b3 = *reinterpret_cast(&b[ldb * 2]); + VecType b4 = *reinterpret_cast(&b[ldb * 3]); + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + VecType vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row3 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row5 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[0]) + = reinterpret_cast(vec_add(vx_row1, vmask)); + if (tail_N == 0) { + 
*reinterpret_cast(&Bp[16]) + = reinterpret_cast(vec_add(vx_row3, vmask)); + *reinterpret_cast(&Bp[32]) + = reinterpret_cast(vec_add(vx_row5, vmask)); + *reinterpret_cast(&Bp[48]) + = reinterpret_cast(vec_add(vx_row7, vmask)); + } + if (tail_N >= 1) { + *reinterpret_cast(&Bp[32]) + = reinterpret_cast(vec_add(vx_row3, vmask)); + *reinterpret_cast(&Bp[64]) + = reinterpret_cast(vec_add(vx_row5, vmask)); + *reinterpret_cast(&Bp[96]) + = reinterpret_cast(vec_add(vx_row7, vmask)); + } + + if (tail_N >= 1) { + VecType b5 = reinterpret_cast(vec_splats(Flip)); + VecType b6 = reinterpret_cast(vec_splats(Flip)); + VecType b7 = reinterpret_cast(vec_splats(Flip)); + VecType b8 = reinterpret_cast(vec_splats(Flip)); + b5 = *reinterpret_cast(&b[ldb * 4]); + + if (tail_N == 3) { + b6 = *reinterpret_cast(&b[ldb * 5]); + b7 = *reinterpret_cast(&b[ldb * 6]); + } + if (tail_N == 2) { + b6 = *reinterpret_cast(&b[ldb * 5]); + } + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b5), + reinterpret_cast<__vector int>(b6))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b7), + reinterpret_cast<__vector int>(b8))); + VecType vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b5), + reinterpret_cast<__vector int>(b6))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b7), + reinterpret_cast<__vector int>(b8))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row3 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row5 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[16]) + = reinterpret_cast(vec_add(vx_row1, vmask)); + *reinterpret_cast(&Bp[48]) + = reinterpret_cast(vec_add(vx_row3, vmask)); + *reinterpret_cast(&Bp[80]) + = reinterpret_cast(vec_add(vx_row5, vmask)); + *reinterpret_cast(&Bp[112]) + = reinterpret_cast(vec_add(vx_row7, vmask)); + } + b += 
16; + if (tail_N >= 1) { + Bp += 16 * 8; + } else { + Bp += 16 * 4; + } + y -= 16; + } + + while (y >= 8) { + VecType V0 = reinterpret_cast(vec_splats(Flip)); + VecType V1 = reinterpret_cast(vec_splats(Flip)); + VecType V2 = reinterpret_cast(vec_splats(Flip)); + VecType V3 = reinterpret_cast(vec_splats(Flip)); + VecType D0, D1, D2, D3; + + vec_t swizA = { + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vec_t swizB = { + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + + *(signed long long *)&V0[0] = *(signed long long *)&b[ldb * 0]; + *(signed long long *)&V0[8] = *(signed long long *)&b[ldb * 1]; + *(signed long long *)&V1[0] = *(signed long long *)&b[ldb * 2]; + *(signed long long *)&V1[8] = *(signed long long *)&b[ldb * 3]; + + D0 = vec_perm(V0, V1, swizA); + D2 = vec_perm(V0, V1, swizB); + + *reinterpret_cast(&Bp[0]) + = reinterpret_cast(vec_add(D0, vmask)); + if (tail_N == 0) { + *reinterpret_cast(&Bp[16]) + = reinterpret_cast(vec_add(D2, vmask)); + } else { + *reinterpret_cast(&Bp[32]) + = reinterpret_cast(vec_add(D2, vmask)); + } + + if (tail_N >= 1) { + *(signed long long *)&V2[0] = *(signed long long *)&b[ldb * 4]; + if (tail_N == 3) { + *(signed long long *)&V2[8] + = *(signed long long *)&b[ldb * 5]; + *(signed long long *)&V3[0] + = *(signed long long *)&b[ldb * 6]; + } + if (tail_N == 2) { + *(signed long long *)&V2[8] + = *(signed long long *)&b[ldb * 5]; + } + D1 = vec_perm(V2, V3, swizA); + D3 = vec_perm(V2, V3, swizB); + + *(vec_t *)&Bp[16] = reinterpret_cast(vec_add(D1, vmask)); + *(vec_t *)&Bp[48] = reinterpret_cast(vec_add(D3, vmask)); + } + b += 8; + + if (tail_N >= 1) { + Bp += 8 * 8; + } else { + Bp += 4 * 8; + } + y -= 8; + } + + while (y >= 4) { + int b1 = *reinterpret_cast(&b[0]); + int b2 = *reinterpret_cast(&b[ldb * 1]); + int b3 = *reinterpret_cast(&b[ldb * 2]); + int b4 = *reinterpret_cast(&b[ldb * 3]); + __vector int vec_a = {b1, b2, b3, b4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *(vec_t *)&Bp[0] 
+ = reinterpret_cast(vec_add(vec_row1, vmask)); + Bp += 16; + + if (tail_N >= 1) { + int value = (Flip & 0xFF) | ((Flip & 0xFF) << 8) + | ((Flip & 0xFF) << 16) | ((Flip & 0xFF) << 24); + b2 = value; + b3 = value; + b4 = value; + b1 = *reinterpret_cast(&b[ldb * 4]); + + if (tail_N == 3) { + b2 = *reinterpret_cast(&b[ldb * 5]); + b3 = *reinterpret_cast(&b[ldb * 6]); + } + if (tail_N == 2) { + b2 = *reinterpret_cast(&b[ldb * 5]); + } + __vector int vec_a1 = {b1, b2, b3, b4}; + VecType vec_row2 = reinterpret_cast(vec_a1); + *(vec_t *)&Bp[0] + = reinterpret_cast(vec_add(vec_row2, vmask)); + Bp += 16; + } + y -= 4; + b += 4; + } + + if (y >= 1 && y <= 3) { + VecType vec_tail1 = reinterpret_cast(vec_splats(Flip)); + VecType vec_tail2 = reinterpret_cast(vec_splats(Flip)); + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + vec_tail1[12] = b[ldb * 3]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail1[2] = b[2]; + vec_tail1[6] = b[ldb + 2]; + vec_tail1[10] = b[ldb * 2 + 2]; + vec_tail1[14] = b[ldb * 3 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + } + *reinterpret_cast(&Bp[0]) + = reinterpret_cast(vec_add(vec_tail1, vmask)); + Bp += 16; + + if (tail_N >= 1) { + if (tail_N == 1) { + vec_tail2[0] = b[ldb * 4]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + } + if (y == 2) { vec_tail2[1] = b[ldb * 4 + 1]; } + } + if (tail_N == 2) { + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[6] = b[ldb * 5 + 2]; + } + if (y == 2) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + } + } + if (tail_N == 3) { + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + vec_tail2[8] = 
b[ldb * 6]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[6] = b[ldb * 5 + 2]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[10] = b[ldb * 6 + 2]; + } + if (y == 2) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + } + } + *reinterpret_cast(&Bp[0]) + = reinterpret_cast(vec_add(vec_tail2, vmask)); + Bp += 16; + } + } + } + if (N_dim <= 3 && N_dim >= 1) { + + const BufType *b = B; + size_t y = K_dim; + while (y >= 4) { + int value = (Flip & 0xFF) | ((Flip & 0xFF) << 8) + | ((Flip & 0xFF) << 16) | ((Flip & 0xFF) << 24); + int a1 = *reinterpret_cast(&b[0]); + int a2 = value, a3 = value, a4 = value; + if (N_dim == 3) { + a2 = *reinterpret_cast(&b[ldb * 1]); + a3 = *reinterpret_cast(&b[ldb * 2]); + } + if (N_dim == 2) { + a2 = *reinterpret_cast(&b[ldb * 1]); + } + + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *(vec_t *)&Bp[0] + = reinterpret_cast(vec_add(vec_row1, vmask)); + + Bp += 16; + b += 4; + y -= 4; + } + + if (y >= 1 && y <= 3) { + VecType vec_tail1 = reinterpret_cast(vec_splats(Flip)); + + int tail_N = N_dim; + + if (tail_N >= 1) { + if (tail_N == 1) { + vec_tail1[0] = b[0]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + } + if (y == 2) { vec_tail1[1] = b[1]; } + } + if (tail_N == 2) { + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + + vec_tail1[5] = b[ldb * 1 + 1]; + vec_tail1[6] = b[ldb * 1 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb * 1 + 1]; + } + } + if (tail_N == 3) { + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + + vec_tail1[5] = b[ldb + 1]; + vec_tail1[6] = b[ldb + 2]; + + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[10] = b[ldb * 2 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + + 
vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + } + } + + *reinterpret_cast(&Bp[0]) + = reinterpret_cast(vec_add(vec_tail1, vmask)); + Bp += 16; + } + } + } + return 0; +} + +template +inline int packB_N8bit(dim_t K_dim, dim_t N_dim, const b_type *B, dim_t ldb, + uint8_t *Bp, bool is_signed) { + if (is_signed) { + pack_N8_8bit_V2_lxvp_signed<__vector signed char, b_type>( + K_dim, N_dim, B, ldb, Bp, true); + } else { + pack_N8_8bit_V2_lxvp_signed<__vector unsigned char, b_type>( + K_dim, N_dim, B, ldb, Bp, false); + } + return 0; +} + +template +int pack_N8_8bit_V2( + dim_t K_dim, dim_t N_dim, const uint8_t *B, dim_t ldb, uint8_t *Bp) { + int K_block = (K_dim + 3) & (~3); + int N_block = (N_dim + 3) & (~3); + + while (N_dim >= 8) { + const uint8_t *b = B; + size_t y = K_dim; + + while (y >= 8) { + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + VecType swizA = { + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + VecType swizB = { + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + + *(signed long long *)&V0[0] = *(signed long long *)&b[ldb * 0]; + *(signed long long *)&V0[8] = *(signed long long *)&b[ldb * 1]; + *(signed long long *)&V1[0] = *(signed long long *)&b[ldb * 2]; + *(signed long long *)&V1[8] = *(signed long long *)&b[ldb * 3]; + *(signed long long *)&V2[0] = *(signed long long *)&b[ldb * 4]; + *(signed long long *)&V2[8] = *(signed long long *)&b[ldb * 5]; + *(signed long long *)&V3[0] = *(signed long long *)&b[ldb * 6]; + *(signed long long *)&V3[8] = *(signed long long *)&b[ldb * 7]; + + D0 = vec_perm(V0, V1, swizA); + D1 = vec_perm(V2, V3, swizA); + D2 = vec_perm(V0, V1, swizB); + D3 = vec_perm(V2, V3, swizB); + + *(VecType *)&Bp[0] = D0; + *(VecType *)&Bp[16] = D1; + *(VecType *)&Bp[32] = D2; + *(VecType *)&Bp[48] = D3; + + Bp += 64; + b += 8; + y -= 8; + } + + while (y >= 4) { + int a1 = *reinterpret_cast(&b[0]); + int a2 = *reinterpret_cast(&b[ldb * 1]); + int a3 = *reinterpret_cast(&b[ldb * 2]); + int 
a4 = *reinterpret_cast(&b[ldb * 3]); + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *reinterpret_cast(&Bp[0]) = vec_row1; + + a1 = *reinterpret_cast(&b[ldb * 4]); + a2 = *reinterpret_cast(&b[ldb * 5]); + a3 = *reinterpret_cast(&b[ldb * 6]); + a4 = *reinterpret_cast(&b[ldb * 7]); + __vector int vec_a1 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a1); + *reinterpret_cast(&Bp[16]) = vec_row1; + Bp += 32; + y -= 4; + b += 4; + } + if (y >= 1 && y <= 3) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail2 + = reinterpret_cast(vec_splats(uint8_t(0))); + + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + vec_tail1[12] = b[ldb * 3]; + + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + vec_tail2[8] = b[ldb * 6]; + vec_tail2[12] = b[ldb * 7]; + + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[13] = b[ldb * 7 + 1]; + + vec_tail1[2] = b[2]; + vec_tail1[6] = b[ldb + 2]; + vec_tail1[10] = b[ldb * 2 + 2]; + vec_tail1[14] = b[ldb * 3 + 2]; + + vec_tail2[2] = b[ldb * 4 + 2]; + vec_tail2[6] = b[ldb * 5 + 2]; + vec_tail2[10] = b[ldb * 6 + 2]; + vec_tail2[14] = b[ldb * 7 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[13] = b[ldb * 7 + 1]; + } + + *reinterpret_cast(&Bp[0]) = vec_tail1; + *reinterpret_cast(&Bp[16]) = vec_tail2; + Bp += 32; + } + + N_dim -= 8; + B += 8 * ldb; + } + + if (N_dim >= 4 && N_dim < 8) { + + const uint8_t *b = B; + size_t y = K_dim; + while (y >= 4) { + int a1 = *reinterpret_cast(&b[0]); + int a2 = *reinterpret_cast(&b[ldb * 1]); + int a3 
= *reinterpret_cast(&b[ldb * 2]); + int a4 = *reinterpret_cast(&b[ldb * 3]); + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *(VecType *)&Bp[0] = vec_row1; + Bp += 16; + + int tail_N = N_dim - 4; + if (tail_N >= 1) { + a2 = 0; + a3 = 0; + a4 = 0; + a1 = *reinterpret_cast(&b[ldb * 4]); + if (tail_N == 3) { + a2 = *reinterpret_cast(&b[ldb * 5]); + a3 = *reinterpret_cast(&b[ldb * 6]); + } + if (tail_N == 2) { + a2 = *reinterpret_cast(&b[ldb * 5]); + } + __vector int vec_a1 = {a1, a2, a3, a4}; + VecType vec_row2 = reinterpret_cast(vec_a1); + *(VecType *)&Bp[0] = vec_row2; + Bp += 16; + //y -= 4; + } + y -= 4; + b += 4; + } + if (y >= 1 && y <= 3) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail2 + = reinterpret_cast(vec_splats(uint8_t(0))); + + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + vec_tail1[12] = b[ldb * 3]; + + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail1[2] = b[2]; + vec_tail1[6] = b[ldb + 2]; + vec_tail1[10] = b[ldb * 2 + 2]; + vec_tail1[14] = b[ldb * 3 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + } + *reinterpret_cast(&Bp[0]) = vec_tail1; + Bp += 16; + + int tail_N = N_dim - 4; + if (tail_N >= 1) { + if (tail_N == 1) { + vec_tail2[0] = b[ldb * 4]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + } + if (y == 2) { vec_tail2[1] = b[ldb * 4 + 1]; } + } + if (tail_N == 2) { + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[6] = b[ldb * 5 + 2]; + } + if (y == 2) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + } + } + if (tail_N == 3) { + vec_tail2[0] = b[ldb * 
4]; + vec_tail2[4] = b[ldb * 5]; + vec_tail2[8] = b[ldb * 6]; + if (y == 3) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[2] = b[ldb * 4 + 2]; + + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[6] = b[ldb * 5 + 2]; + + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[10] = b[ldb * 6 + 2]; + } + if (y == 2) { + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + } + } + *reinterpret_cast(&Bp[0]) = vec_tail2; + Bp += 16; + } + } + + B += N_dim * ldb; + N_dim = 0; + } + + if (N_dim >= 1 && N_dim <= 3) { + const uint8_t *b = B; + size_t y = K_dim; + while (y >= 4) { + int a1 = *reinterpret_cast(&b[0]); + int a2 = 0, a3 = 0, a4 = 0; + if (N_dim == 3) { + a2 = *reinterpret_cast(&b[ldb * 1]); + a3 = *reinterpret_cast(&b[ldb * 2]); + } + if (N_dim == 2) { + a2 = *reinterpret_cast(&b[ldb * 1]); + } + + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *(VecType *)&Bp[0] = vec_row1; + + Bp += 16; + b += 4; + y -= 4; + } + + if (y >= 1 && y <= 3) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + + int tail_N = N_dim; + + if (tail_N >= 1) { + if (tail_N == 1) { + vec_tail1[0] = b[0]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + } + if (y == 2) { vec_tail1[1] = b[1]; } + } + if (tail_N == 2) { + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + + vec_tail1[5] = b[ldb * 1 + 1]; + vec_tail1[6] = b[ldb * 1 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb * 1 + 1]; + } + } + if (tail_N == 3) { + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[2] = b[2]; + + vec_tail1[5] = b[ldb + 1]; + vec_tail1[6] = b[ldb + 2]; + + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[10] = b[ldb * 2 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + } + } + + 
*reinterpret_cast(&Bp[0]) = vec_tail1; + Bp += 16; + } + } + } return 0; } -int pack_N8_8bit(dim_t k, dim_t n, const uint8_t *b, dim_t ldb, uint8_t *bp) { +template +inline int pack_N8_8bit(dim_t k, dim_t n, const T *b, dim_t ldb, T *bp) { int32_t i, j; int32_t kcell, cell, koff, noff, krows, k8, n8; int32_t n_cap = (n + 3) & ~3; @@ -832,8 +3523,8 @@ int pack_N8_8bit(dim_t k, dim_t n, const uint8_t *b, dim_t ldb, uint8_t *bp) { 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; vec_t swizB = { 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - const uint8_t *src = &b[ldb * j + i]; - uint8_t *dest = &bp[16 * (krows * (j >> 2) + (i >> 1))]; + const T *src = &b[ldb * j + i]; + T *dest = &bp[16 * (krows * (j >> 2) + (i >> 1))]; *(signed long long *)&V0[0] = *(signed long long *)&src[ldb * 0]; *(signed long long *)&V0[8] = *(signed long long *)&src[ldb * 1]; @@ -915,8 +3606,1636 @@ int pack_N8_8bit(dim_t k, dim_t n, const uint8_t *b, dim_t ldb, uint8_t *bp) { return 0; } +template +void tailBlock16_12xK(int K_Dim, int N_Dim, const int8_t *A, int lda, + int8_t *Apacked, int32_t *row_sum_eff) { + + __vector signed int vsum1 = {0}; + __vector signed int vsum2 = {0}; + __vector signed int vsum3 = {0}; + __vector signed int vsum4 = {0}; + + if (N_Dim >= 13 && N_Dim < 16) { + const int8_t *a = A; + size_t y = K_Dim; + + while (y >= 4) { + int a1 = *reinterpret_cast(&a[0]); + int a2 = *reinterpret_cast(&a[lda * 1]); + + int a3 = *reinterpret_cast(&a[lda * 2]); + int a4 = *reinterpret_cast(&a[lda * 3]); + + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *reinterpret_cast(&Apacked[0]) = vec_row1; + vsum1 = vec_sum4s(vec_row1, vsum1); + + a1 = *reinterpret_cast(&a[lda * 4]); + a2 = *reinterpret_cast(&a[lda * 5]); + a3 = *reinterpret_cast(&a[lda * 6]); + a4 = *reinterpret_cast(&a[lda * 7]); + + __vector int vec_a2 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a2); + *reinterpret_cast(&Apacked[16]) = vec_row1; + 
vsum2 = vec_sum4s(vec_row1, vsum2); + + a1 = *reinterpret_cast(&a[lda * 8]); + a2 = *reinterpret_cast(&a[lda * 9]); + a3 = *reinterpret_cast(&a[lda * 10]); + a4 = *reinterpret_cast(&a[lda * 11]); + + __vector int vec_a3 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a3); + *reinterpret_cast(&Apacked[32]) = vec_row1; + vsum3 = vec_sum4s(vec_row1, vsum3); + + Apacked += 48; + + int tail_N = N_Dim - 12; + if (tail_N >= 1) { + a1 = 0, a2 = 0, a3 = 0, a4 = 0; + if (tail_N == 1) { + a1 = *reinterpret_cast(&a[lda * 12]); + } + if (tail_N == 2) { + a1 = *reinterpret_cast(&a[lda * 12]); + a2 = *reinterpret_cast(&a[lda * 13]); + } + if (tail_N == 3) { + a1 = *reinterpret_cast(&a[lda * 12]); + a2 = *reinterpret_cast(&a[lda * 13]); + a3 = *reinterpret_cast(&a[lda * 14]); + } + __vector int vec_a4 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a4); + *reinterpret_cast(&Apacked[0]) = vec_row1; + vsum4 = vec_sum4s(vec_row1, vsum4); + Apacked += 16; + } + y -= 4; + a += 4; + } + + if (y >= 1 && y <= 3) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail2 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail3 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail4 + = reinterpret_cast(vec_splats(uint8_t(0))); + + // 1st 4 rows + vec_tail1[0] = a[0]; + vec_tail1[4] = a[lda]; + vec_tail1[8] = a[lda * 2]; + vec_tail1[12] = a[lda * 3]; + + if (y == 3) { + vec_tail1[1] = a[1]; + vec_tail1[5] = a[lda + 1]; + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[13] = a[lda * 3 + 1]; + + vec_tail1[2] = a[2]; + vec_tail1[6] = a[lda + 2]; + vec_tail1[10] = a[lda * 2 + 2]; + vec_tail1[14] = a[lda * 3 + 2]; + } + if (y == 2) { + vec_tail1[1] = a[1]; + vec_tail1[5] = a[lda + 1]; + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[13] = a[lda * 3 + 1]; + } + + *reinterpret_cast(&Apacked[0]) = vec_tail1; + vsum1 = vec_sum4s(vec_tail1, vsum1); + + vec_tail2[0] = a[lda * 4]; + vec_tail2[4] = a[lda * 5]; + vec_tail2[8] = a[lda * 6]; + 
vec_tail2[12] = a[lda * 7]; + if (y == 3) { + vec_tail2[1] = a[lda * 4 + 1]; + vec_tail2[5] = a[lda * 5 + 1]; + vec_tail2[9] = a[lda * 6 + 1]; + vec_tail2[13] = a[lda * 7 + 1]; + + vec_tail2[2] = a[lda * 4 + 2]; + vec_tail2[6] = a[lda * 5 + 2]; + vec_tail2[10] = a[lda * 6 + 2]; + vec_tail2[14] = a[lda * 7 + 2]; + } + if (y == 2) { + vec_tail2[1] = a[lda * 4 + 1]; + vec_tail2[5] = a[lda * 5 + 1]; + vec_tail2[9] = a[lda * 6 + 1]; + vec_tail2[13] = a[lda * 7 + 1]; + } + + *reinterpret_cast(&Apacked[16]) = vec_tail2; + vsum2 = vec_sum4s(vec_tail2, vsum2); + + vec_tail3[0] = a[lda * 8]; + vec_tail3[4] = a[lda * 9]; + vec_tail3[8] = a[lda * 10]; + vec_tail3[12] = a[lda * 11]; + + if (y == 3) { + vec_tail3[1] = a[lda * 8 + 1]; + vec_tail3[5] = a[lda * 9 + 1]; + vec_tail3[9] = a[lda * 10 + 1]; + vec_tail3[13] = a[lda * 11 + 1]; + + vec_tail3[2] = a[lda * 8 + 2]; + vec_tail3[6] = a[lda * 9 + 2]; + vec_tail3[10] = a[lda * 10 + 2]; + vec_tail3[14] = a[lda * 11 + 2]; + } + if (y == 2) { + vec_tail3[1] = a[lda * 8 + 1]; + vec_tail3[5] = a[lda * 9 + 1]; + vec_tail3[9] = a[lda * 10 + 1]; + vec_tail3[13] = a[lda * 11 + 1]; + } + *reinterpret_cast(&Apacked[32]) = vec_tail3; + vsum3 = vec_sum4s(vec_tail3, vsum3); + + int tail_N = N_Dim - 12; + if (tail_N >= 1) { + if (tail_N == 1) { + vec_tail4[0] = a[lda * 12]; + if (y == 3) { + vec_tail4[1] = a[lda * 12 + 1]; + vec_tail4[2] = a[lda * 12 + 2]; + } + if (y == 2) { vec_tail4[1] = a[lda * 12 + 1]; } + } + if (tail_N == 2) { + vec_tail4[0] = a[lda * 12]; + vec_tail4[4] = a[lda * 13]; + if (y == 3) { + vec_tail4[1] = a[lda * 12 + 1]; + vec_tail4[2] = a[lda * 12 + 2]; + + vec_tail4[5] = a[lda * 13 + 1]; + vec_tail4[6] = a[lda * 13 + 2]; + } + if (y == 2) { + vec_tail4[1] = a[lda * 12 + 1]; + vec_tail4[5] = a[lda * 13 + 1]; + } + } + if (tail_N == 3) { + vec_tail4[0] = a[lda * 12]; + vec_tail4[4] = a[lda * 13]; + vec_tail4[8] = a[lda * 14]; + if (y == 3) { + vec_tail4[1] = a[lda * 12 + 1]; + vec_tail4[2] = a[lda * 12 + 2]; + + 
vec_tail4[5] = a[lda * 13 + 1]; + vec_tail4[6] = a[lda * 13 + 2]; + + vec_tail4[9] = a[lda * 14 + 1]; + vec_tail4[10] = a[lda * 14 + 2]; + } + if (y == 2) { + vec_tail4[1] = a[lda * 12 + 1]; + vec_tail4[5] = a[lda * 13 + 1]; + vec_tail4[9] = a[lda * 14 + 1]; + } + } + *reinterpret_cast(&Apacked[48]) = vec_tail4; + vsum4 = vec_sum4s(vec_tail4, vsum4); + } + Apacked += 64; + } + + row_sum_eff[0] = vsum1[0]; + row_sum_eff[1] = vsum1[1]; + row_sum_eff[2] = vsum1[2]; + row_sum_eff[3] = vsum1[3]; + + row_sum_eff[4] = vsum2[0]; + row_sum_eff[5] = vsum2[1]; + row_sum_eff[6] = vsum2[2]; + row_sum_eff[7] = vsum2[3]; + + row_sum_eff[8] = vsum3[0]; + row_sum_eff[9] = vsum3[1]; + row_sum_eff[10] = vsum3[2]; + row_sum_eff[11] = vsum3[3]; + + row_sum_eff += 12; + + int tail_N = N_Dim - 12; + + if (tail_N == 1) { row_sum_eff[0] = vsum4[0]; } + if (tail_N == 2) { + row_sum_eff[0] = vsum4[0]; + row_sum_eff[1] = vsum4[1]; + } + if (tail_N == 3) { + row_sum_eff[0] = vsum4[0]; + row_sum_eff[1] = vsum4[1]; + row_sum_eff[2] = vsum4[2]; + } + row_sum_eff += tail_N; + } +} + +template +void tailBlock16_8xK(int K_Dim, int N_Dim, const int8_t *A, int lda, + int8_t *Apacked, int32_t *row_sum_eff) { + + while (N_Dim >= 8) { + + const int8_t *a = A; + size_t y = K_Dim; + __vector signed int vsum1 = {0}; + __vector signed int vsum2 = {0}; + + while (y >= 16) { + VecType a1 = *reinterpret_cast(&a[0]); + VecType a2 = *reinterpret_cast(&a[lda]); + VecType a3 = *reinterpret_cast(&a[lda * 2]); + VecType a4 = *reinterpret_cast(&a[lda * 3]); + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + VecType vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector 
int>(a3), + reinterpret_cast<__vector int>(a4))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row3 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row5 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Apacked[0]) = vx_row1; + *reinterpret_cast(&Apacked[32]) = vx_row3; + *reinterpret_cast(&Apacked[64]) = vx_row5; + *reinterpret_cast(&Apacked[96]) = vx_row7; + + vsum1 = vec_sum4s(vx_row1, vsum1); + vsum1 = vec_sum4s(vx_row3, vsum1); + vsum1 = vec_sum4s(vx_row5, vsum1); + vsum1 = vec_sum4s(vx_row7, vsum1); + + // 2nd 4 Columns + a1 = *reinterpret_cast(&a[lda * 4]); + a2 = *reinterpret_cast(&a[lda * 5]); + a3 = *reinterpret_cast(&a[lda * 6]); + a4 = *reinterpret_cast(&a[lda * 7]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + + VecType vx_row2 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row4 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row6 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row8 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Apacked[16]) = vx_row2; + *reinterpret_cast(&Apacked[48]) = vx_row4; + *reinterpret_cast(&Apacked[80]) = vx_row6; + *reinterpret_cast(&Apacked[112]) = vx_row8; + + vsum2 = vec_sum4s(vx_row2, vsum2); + vsum2 = vec_sum4s(vx_row4, vsum2); + vsum2 = vec_sum4s(vx_row6, vsum2); + vsum2 = vec_sum4s(vx_row8, vsum2); + + a += 16; + Apacked += 128; + y -= 16; + } + + while (y >= 4) { + int a1 = *reinterpret_cast(&a[0]); + int a2 = 
*reinterpret_cast(&a[lda * 1]); + int a3 = *reinterpret_cast(&a[lda * 2]); + int a4 = *reinterpret_cast(&a[lda * 3]); + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *reinterpret_cast(&Apacked[0]) = vec_row1; + vsum1 = vec_sum4s(vec_row1, vsum1); + + // Next 4 Column + a1 = *reinterpret_cast(&a[lda * 4]); + a2 = *reinterpret_cast(&a[lda * 5]); + a3 = *reinterpret_cast(&a[lda * 6]); + a4 = *reinterpret_cast(&a[lda * 7]); + __vector int vec_a1 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a1); + *reinterpret_cast(&Apacked[16]) = vec_row1; + vsum2 = vec_sum4s(vec_row1, vsum2); + Apacked += 32; + a += 4; + y -= 4; + } + + if (y <= 3 && y >= 1) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail2 + = reinterpret_cast(vec_splats(uint8_t(0))); + + vec_tail1[0] = a[0]; + vec_tail1[4] = a[lda]; + vec_tail1[8] = a[lda * 2]; + vec_tail1[12] = a[lda * 3]; + + vec_tail2[0] = a[lda * 4]; + vec_tail2[4] = a[lda * 5]; + vec_tail2[8] = a[lda * 6]; + vec_tail2[12] = a[lda * 7]; + + if (y == 3) { + vec_tail1[1] = a[1]; + vec_tail1[2] = a[2]; + vec_tail1[5] = a[lda + 1]; + vec_tail1[6] = a[lda + 2]; + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[10] = a[lda * 2 + 2]; + vec_tail1[13] = a[lda * 3 + 1]; + vec_tail1[14] = a[lda * 3 + 2]; + + vec_tail2[1] = a[lda * 4 + 1]; + vec_tail2[2] = a[lda * 4 + 2]; + vec_tail2[5] = a[lda * 5 + 1]; + vec_tail2[6] = a[lda * 5 + 2]; + vec_tail2[9] = a[lda * 6 + 1]; + vec_tail2[10] = a[lda * 6 + 2]; + vec_tail2[13] = a[lda * 7 + 1]; + vec_tail2[14] = a[lda * 7 + 2]; + } + if (y == 2) { + vec_tail1[1] = a[1]; + vec_tail1[5] = a[lda + 1]; + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[13] = a[lda * 3 + 1]; + + vec_tail2[1] = a[lda * 4 + 1]; + vec_tail2[5] = a[lda * 5 + 1]; + vec_tail2[9] = a[lda * 6 + 1]; + vec_tail2[13] = a[lda * 7 + 1]; + } + *reinterpret_cast(&Apacked[0]) = vec_tail1; + *reinterpret_cast(&Apacked[16]) = vec_tail2; + vsum1 = vec_sum4s(vec_tail1, vsum1); + 
vsum2 = vec_sum4s(vec_tail2, vsum2); + + Apacked += 32; + } + row_sum_eff[0] = vsum1[0]; + row_sum_eff[1] = vsum1[1]; + row_sum_eff[2] = vsum1[2]; + row_sum_eff[3] = vsum1[3]; + + row_sum_eff[4] = vsum2[0]; + row_sum_eff[5] = vsum2[1]; + row_sum_eff[6] = vsum2[2]; + row_sum_eff[7] = vsum2[3]; + + row_sum_eff += 8; + A += 8 * lda; + N_Dim -= 8; + } + while (N_Dim >= 4) { + const int8_t *a = A; + size_t y = K_Dim; + __vector signed int vsum1 = {0}; + __vector signed int vsum2 = {0}; + int tail_N = N_Dim - 4; + + while (y >= 4) { + int a1 = *reinterpret_cast(&a[0]); + int a2 = *reinterpret_cast(&a[lda * 1]); + int a3 = *reinterpret_cast(&a[lda * 2]); + int a4 = *reinterpret_cast(&a[lda * 3]); + + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row = reinterpret_cast(vec_a); + *reinterpret_cast(&Apacked[0]) = vec_row; + vsum1 = vec_sum4s(vec_row, vsum1); + Apacked += 16; + + if (tail_N >= 1) { + int a1 = *reinterpret_cast(&a[lda * 4]); + int a2 = 0, a3 = 0, a4 = 0; + + if (tail_N == 2) { + a2 = *reinterpret_cast(&a[lda * 5]); + } + if (tail_N == 3) { + a2 = *reinterpret_cast(&a[lda * 5]); + a3 = *reinterpret_cast(&a[lda * 6]); + } + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row = reinterpret_cast(vec_a); + *reinterpret_cast(&Apacked[0]) = vec_row; + vsum2 = vec_sum4s(vec_row, vsum2); + Apacked += 16; + } + + a += 4; + y -= 4; + } + if (y >= 1 && y <= 3) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + + vec_tail1[0] = a[0]; + vec_tail1[4] = a[lda * 1]; + vec_tail1[8] = a[lda * 2]; + vec_tail1[12] = a[lda * 3]; + if (y == 3) { + vec_tail1[1] = a[1]; + vec_tail1[2] = a[2]; + vec_tail1[5] = a[lda * 1 + 1]; + vec_tail1[6] = a[lda * 1 + 2]; + + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[10] = a[lda * 2 + 2]; + vec_tail1[13] = a[lda * 3 + 1]; + vec_tail1[14] = a[lda * 3 + 2]; + } + if (y == 2) { + vec_tail1[1] = a[1]; + vec_tail1[5] = a[lda * 1 + 1]; + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[13] = a[lda * 3 + 1]; + } + 
*reinterpret_cast(&Apacked[0]) = vec_tail1; + vsum1 = vec_sum4s(vec_tail1, vsum1); + + Apacked += 16; + if (tail_N >= 1) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + if (tail_N == 1) { + vec_tail1[0] = a[lda * 4]; + if (y == 3) { + vec_tail1[1] = a[lda * 4 + 1]; + vec_tail1[2] = a[lda * 4 + 2]; + } + if (y == 2) { vec_tail1[1] = a[lda * 4 + 1]; } + } + if (tail_N == 2) { + vec_tail1[0] = a[lda * 4]; + vec_tail1[4] = a[lda * 5]; + if (y == 3) { + vec_tail1[1] = a[lda * 4 + 1]; + vec_tail1[2] = a[lda * 4 + 2]; + + vec_tail1[5] = a[lda * 5 + 1]; + vec_tail1[6] = a[lda * 5 + 2]; + } + if (y == 2) { + vec_tail1[1] = a[lda * 4 + 1]; + vec_tail1[5] = a[lda * 5 + 1]; + } + } + if (tail_N == 3) { + vec_tail1[0] = a[lda * 4]; + vec_tail1[4] = a[lda * 5]; + vec_tail1[8] = a[lda * 6]; + if (y == 3) { + vec_tail1[1] = a[lda * 4 + 1]; + vec_tail1[2] = a[lda * 4 + 2]; + vec_tail1[5] = a[lda * 5 + 1]; + vec_tail1[6] = a[lda * 5 + 2]; + vec_tail1[9] = a[lda * 6 + 1]; + vec_tail1[10] = a[lda * 6 + 2]; + } + if (y == 2) { + vec_tail1[1] = a[lda * 4 + 1]; + vec_tail1[5] = a[lda * 5 + 1]; + vec_tail1[9] = a[lda * 6 + 1]; + } + } + *reinterpret_cast(&Apacked[0]) = vec_tail1; + vsum2 = vec_sum4s(vec_tail1, vsum2); + Apacked += 16; + } + } + row_sum_eff[0] = vsum1[0]; + row_sum_eff[1] = vsum1[1]; + row_sum_eff[2] = vsum1[2]; + row_sum_eff[3] = vsum1[3]; + row_sum_eff += 4; + + if (tail_N == 1) { row_sum_eff[0] = vsum2[0]; } + if (tail_N == 2) { + row_sum_eff[0] = vsum2[0]; + row_sum_eff[1] = vsum2[1]; + } + if (tail_N == 3) { + row_sum_eff[0] = vsum2[0]; + row_sum_eff[1] = vsum2[1]; + row_sum_eff[2] = vsum2[2]; + } + row_sum_eff += tail_N; + + A += N_Dim * lda; + N_Dim -= N_Dim; + } + + if (N_Dim >= 1 && N_Dim <= 3) { + + const int8_t *a = A; + size_t y = K_Dim; + __vector signed int vsum1 = {0}; + while (y >= 4) { + int a1 = *reinterpret_cast(&a[0]); + int a2 = 0, a3 = 0, a4 = 0; + if (N_Dim == 3) { + a3 = *reinterpret_cast(&a[lda * 2]); + a2 = 
*reinterpret_cast(&a[lda * 1]); + } + if (N_Dim == 2) { + a2 = *reinterpret_cast(&a[lda * 1]); + } + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *(VecType *)&Apacked[0] = vec_row1; + + vsum1 = vec_sum4s(vec_row1, vsum1); + + a += 4; + Apacked += 16; + y -= 4; + } + if (y >= 1 && y <= 3) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + + int tail_N = N_Dim; + + if (tail_N >= 1) { + if (tail_N == 1) { + vec_tail1[0] = a[0]; + if (y == 3) { + vec_tail1[1] = a[1]; + vec_tail1[2] = a[2]; + } + if (y == 2) { vec_tail1[1] = a[1]; } + } + if (tail_N == 2) { + vec_tail1[0] = a[0]; + vec_tail1[4] = a[lda]; + if (y == 3) { + vec_tail1[1] = a[1]; + vec_tail1[2] = a[2]; + + vec_tail1[5] = a[lda + 1]; + vec_tail1[6] = a[lda + 2]; + } + if (y == 2) { + vec_tail1[1] = a[1]; + + vec_tail1[5] = a[lda + 1]; + } + } + if (tail_N == 3) { + vec_tail1[0] = a[0]; + vec_tail1[4] = a[lda]; + vec_tail1[8] = a[lda * 2]; + + if (y == 3) { + vec_tail1[1] = a[1]; + vec_tail1[2] = a[2]; + + vec_tail1[5] = a[lda + 1]; + vec_tail1[6] = a[lda + 2]; + + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[10] = a[lda * 2 + 2]; + } + if (y == 2) { + vec_tail1[1] = a[1]; + + vec_tail1[5] = a[lda + 1]; + + vec_tail1[9] = a[lda * 2 + 1]; + } + } + *reinterpret_cast(&Apacked[0]) = vec_tail1; + vsum1 = vec_sum4s(vec_tail1, vsum1); + Apacked += 16; + } + } + if (N_Dim == 3) { + row_sum_eff[0] = vsum1[0]; + row_sum_eff[1] = vsum1[1]; + row_sum_eff[2] = vsum1[2]; + } + if (N_Dim == 2) { + row_sum_eff[0] = vsum1[0]; + row_sum_eff[1] = vsum1[1]; + } + if (N_Dim == 1) { row_sum_eff[0] = vsum1[0]; } + row_sum_eff += N_Dim; + } +} + +template +int pack_N16_8bit_V2_lxvp(int K_dim, int N_dim, const int8_t *B, int ldb, + int8_t *Bp, int32_t *row_sum_eff) { + + while (N_dim >= 16) { + int8_t *b = const_cast(B); + size_t y = K_dim; + __vector signed int vsum1 = {0}; + __vector signed int vsum2 = {0}; + __vector signed int vsum3 = {0}; + __vector signed int vsum4 = 
{0}; + while (y >= 32) { + + __vector_pair row1, row2, row3, row4, row5, row6, row7, row8; + __vector_pair row9, row10, row11, row12, row13, row14, row15, row16; + VecType r1[2] = {0}, r2[2] = {0}, r3[2] = {0}, r4[2] = {0}, + r5[2] = {0}, r6[2] = {0}, r7[2] = {0}, r8[2] = {0}; + VecType r9[2] = {0}, r10[2] = {0}, r11[2] = {0}, r12[2] = {0}, + r13[2] = {0}, r14[2] = {0}, r15[2] = {0}, r16[2] = {0}; + + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + + row1 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[0])); + row2 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb])); + row3 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 2])); + row4 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 3])); + row5 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 4])); + row6 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 5])); + row7 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 6])); + row8 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 7])); + row9 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 8])); + row10 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 9])); + row11 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 10])); + row12 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 11])); + row13 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 12])); + row14 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 13])); + row15 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 14])); + row16 = __builtin_vsx_lxvp( + 0, reinterpret_cast<__vector_pair *>(&b[ldb * 15])); + + __builtin_vsx_disassemble_pair(r1, &row1); + __builtin_vsx_disassemble_pair(r2, &row2); + __builtin_vsx_disassemble_pair(r3, &row3); + __builtin_vsx_disassemble_pair(r4, &row4); + 
__builtin_vsx_disassemble_pair(r5, &row5); + __builtin_vsx_disassemble_pair(r6, &row6); + __builtin_vsx_disassemble_pair(r7, &row7); + __builtin_vsx_disassemble_pair(r8, &row8); + __builtin_vsx_disassemble_pair(r9, &row9); + __builtin_vsx_disassemble_pair(r10, &row10); + __builtin_vsx_disassemble_pair(r11, &row11); + __builtin_vsx_disassemble_pair(r12, &row12); + __builtin_vsx_disassemble_pair(r13, &row13); + __builtin_vsx_disassemble_pair(r14, &row14); + __builtin_vsx_disassemble_pair(r15, &row15); + __builtin_vsx_disassemble_pair(r16, &row16); + + // First 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[0]), + reinterpret_cast<__vector int>(r2[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[0]), + reinterpret_cast<__vector int>(r4[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[0] = D0; + *(VecType *)&Bp[64] = D1; + *(VecType *)&Bp[128] = D2; + *(VecType *)&Bp[192] = D3; + + vsum1 = vec_sum4s(D0, vsum1); + vsum1 = vec_sum4s(D1, vsum1); + vsum1 = vec_sum4s(D2, vsum1); + vsum1 = vec_sum4s(D3, vsum1); + + // Next (ldb * 4) 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[0]), + reinterpret_cast<__vector int>(r6[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[0]), + reinterpret_cast<__vector int>(r8[0]))); + + D0 = 
vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[16] = D0; + *(VecType *)&Bp[80] = D1; + *(VecType *)&Bp[144] = D2; + *(VecType *)&Bp[208] = D3; + + vsum2 = vec_sum4s(D0, vsum2); + vsum2 = vec_sum4s(D1, vsum2); + vsum2 = vec_sum4s(D2, vsum2); + vsum2 = vec_sum4s(D3, vsum2); + + // Third (ldb * 8) 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r9[0]), + reinterpret_cast<__vector int>(r10[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r11[0]), + reinterpret_cast<__vector int>(r12[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r9[0]), + reinterpret_cast<__vector int>(r10[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r11[0]), + reinterpret_cast<__vector int>(r12[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[32] = D0; + *(VecType *)&Bp[96] = D1; + *(VecType *)&Bp[160] = D2; + *(VecType *)&Bp[224] = D3; + + vsum3 = vec_sum4s(D0, vsum3); + vsum3 = vec_sum4s(D1, vsum3); + vsum3 = vec_sum4s(D2, vsum3); + vsum3 = vec_sum4s(D3, vsum3); + + // Fourth (ldb * 12) 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r13[0]), + reinterpret_cast<__vector int>(r14[0]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r15[0]), + reinterpret_cast<__vector int>(r16[0]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r13[0]), + reinterpret_cast<__vector int>(r14[0]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r15[0]), + reinterpret_cast<__vector int>(r16[0]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[48] = D0; + *(VecType 
*)&Bp[112] = D1; + *(VecType *)&Bp[176] = D2; + *(VecType *)&Bp[240] = D3; + + vsum4 = vec_sum4s(D0, vsum4); + vsum4 = vec_sum4s(D1, vsum4); + vsum4 = vec_sum4s(D2, vsum4); + vsum4 = vec_sum4s(D3, vsum4); + + Bp += 256; + + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r1[1]), + reinterpret_cast<__vector int>(r2[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r3[1]), + reinterpret_cast<__vector int>(r4[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[0] = D0; + *(VecType *)&Bp[64] = D1; + *(VecType *)&Bp[128] = D2; + *(VecType *)&Bp[192] = D3; + + vsum1 = vec_sum4s(D0, vsum1); + vsum1 = vec_sum4s(D1, vsum1); + vsum1 = vec_sum4s(D2, vsum1); + vsum1 = vec_sum4s(D3, vsum1); + + // Next (ldb * 4) 4 Rows and Second 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r5[1]), + reinterpret_cast<__vector int>(r6[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r7[1]), + reinterpret_cast<__vector int>(r8[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[16] = D0; + *(VecType *)&Bp[80] = D1; + *(VecType *)&Bp[144] = D2; + *(VecType *)&Bp[208] = D3; + + vsum2 = vec_sum4s(D0, vsum2); + vsum2 = vec_sum4s(D1, vsum2); + vsum2 = vec_sum4s(D2, vsum2); + vsum2 = vec_sum4s(D3, vsum2); + + // First 4 
Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r9[1]), + reinterpret_cast<__vector int>(r10[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r11[1]), + reinterpret_cast<__vector int>(r12[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r9[1]), + reinterpret_cast<__vector int>(r10[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r11[1]), + reinterpret_cast<__vector int>(r12[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[32] = D0; + *(VecType *)&Bp[96] = D1; + *(VecType *)&Bp[160] = D2; + *(VecType *)&Bp[224] = D3; + + vsum3 = vec_sum4s(D0, vsum3); + vsum3 = vec_sum4s(D1, vsum3); + vsum3 = vec_sum4s(D2, vsum3); + vsum3 = vec_sum4s(D3, vsum3); + + // Next (ldb * 4) 4 Rows and First 16 columns + V0 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r13[1]), + reinterpret_cast<__vector int>(r14[1]))); + V1 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(r15[1]), + reinterpret_cast<__vector int>(r16[1]))); + V2 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r13[1]), + reinterpret_cast<__vector int>(r14[1]))); + V3 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(r15[1]), + reinterpret_cast<__vector int>(r16[1]))); + + D0 = vec_xxpermdi(V0, V1, 0); + D1 = vec_xxpermdi(V2, V3, 0); + D2 = vec_xxpermdi(V0, V1, 3); + D3 = vec_xxpermdi(V2, V3, 3); + + *(VecType *)&Bp[48] = D0; + *(VecType *)&Bp[112] = D1; + *(VecType *)&Bp[176] = D2; + *(VecType *)&Bp[240] = D3; + + vsum4 = vec_sum4s(D0, vsum4); + vsum4 = vec_sum4s(D1, vsum4); + vsum4 = vec_sum4s(D2, vsum4); + vsum4 = vec_sum4s(D3, vsum4); + + y -= 32; + b += 32; + Bp += 256; + } + while (y >= 16) { + // First 4th row and 16 Columns + VecType b1 = *reinterpret_cast(&b[0]); + VecType b2 = *reinterpret_cast(&b[ldb]); + VecType b3 = 
*reinterpret_cast(&b[ldb * 2]); + VecType b4 = *reinterpret_cast(&b[ldb * 3]); + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + VecType vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row5 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row9 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row13 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[0]) = vx_row1; + *reinterpret_cast(&Bp[64]) = vx_row5; + *reinterpret_cast(&Bp[128]) = vx_row9; + *reinterpret_cast(&Bp[192]) = vx_row13; + + vsum1 = vec_sum4s(vx_row1, vsum1); + vsum1 = vec_sum4s(vx_row5, vsum1); + vsum1 = vec_sum4s(vx_row9, vsum1); + vsum1 = vec_sum4s(vx_row13, vsum1); + + // Second 4th Row and 16 Columns + b1 = *reinterpret_cast(&b[ldb * 4]); + b2 = *reinterpret_cast(&b[ldb * 5]); + b3 = *reinterpret_cast(&b[ldb * 6]); + b4 = *reinterpret_cast(&b[ldb * 7]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row2 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row6 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row10 = 
vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row14 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[16]) = vx_row2; + *reinterpret_cast(&Bp[80]) = vx_row6; + *reinterpret_cast(&Bp[144]) = vx_row10; + *reinterpret_cast(&Bp[208]) = vx_row14; + + vsum2 = vec_sum4s(vx_row2, vsum2); + vsum2 = vec_sum4s(vx_row6, vsum2); + vsum2 = vec_sum4s(vx_row10, vsum2); + vsum2 = vec_sum4s(vx_row14, vsum2); + + b1 = *reinterpret_cast(&b[ldb * 8]); + b2 = *reinterpret_cast(&b[ldb * 9]); + b3 = *reinterpret_cast(&b[ldb * 10]); + b4 = *reinterpret_cast(&b[ldb * 11]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row3 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row11 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row15 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[32]) = vx_row3; + *reinterpret_cast(&Bp[96]) = vx_row7; + *reinterpret_cast(&Bp[160]) = vx_row11; + *reinterpret_cast(&Bp[224]) = vx_row15; + + vsum3 = vec_sum4s(vx_row3, vsum3); + vsum3 = vec_sum4s(vx_row7, vsum3); + vsum3 = vec_sum4s(vx_row11, vsum3); + vsum3 = vec_sum4s(vx_row15, vsum3); + + b1 = *reinterpret_cast(&b[ldb * 12]); + b2 = *reinterpret_cast(&b[ldb * 13]); + b3 = *reinterpret_cast(&b[ldb * 14]); + b4 = *reinterpret_cast(&b[ldb * 15]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(b3), + 
reinterpret_cast<__vector int>(b4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b1), + reinterpret_cast<__vector int>(b2))); + vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(b3), + reinterpret_cast<__vector int>(b4))); + + VecType vx_row4 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row8 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row12 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row16 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Bp[48]) = vx_row4; + *reinterpret_cast(&Bp[112]) = vx_row8; + *reinterpret_cast(&Bp[176]) = vx_row12; + *reinterpret_cast(&Bp[240]) = vx_row16; + + vsum4 = vec_sum4s(vx_row4, vsum4); + vsum4 = vec_sum4s(vx_row8, vsum4); + vsum4 = vec_sum4s(vx_row12, vsum4); + vsum4 = vec_sum4s(vx_row16, vsum4); + + b += 16; + Bp += 256; + y -= 16; + } + while (y >= 8) { + VecType V0, V1, V2, V3; + VecType D0, D1, D2, D3; + __vector unsigned char swizA = { + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + __vector unsigned char swizB = { + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + *(signed long long *)&V0[0] = *(signed long long *)&b[ldb * 0]; + *(signed long long *)&V0[8] = *(signed long long *)&b[ldb * 1]; + *(signed long long *)&V1[0] = *(signed long long *)&b[ldb * 2]; + *(signed long long *)&V1[8] = *(signed long long *)&b[ldb * 3]; + *(signed long long *)&V2[0] = *(signed long long *)&b[ldb * 4]; + *(signed long long *)&V2[8] = *(signed long long *)&b[ldb * 5]; + *(signed long long *)&V3[0] = *(signed long long *)&b[ldb * 6]; + *(signed long long *)&V3[8] = *(signed long long *)&b[ldb * 7]; + + D0 = vec_perm(V0, V1, swizA); + D1 = vec_perm(V2, V3, swizA); + D2 = vec_perm(V0, V1, swizB); + D3 = vec_perm(V2, V3, swizB); + + *(VecType *)&Bp[0] = D0; + *(VecType *)&Bp[16] = D1; + *(VecType *)&Bp[64] = D2; + *(VecType *)&Bp[80] = D3; + + vsum1 = vec_sum4s(D0, vsum1); + vsum1 = vec_sum4s(D2, vsum1); + vsum2 = 
vec_sum4s(D1, vsum2); + vsum2 = vec_sum4s(D3, vsum2); + + *(signed long long *)&V0[0] = *(signed long long *)&b[ldb * 8]; + *(signed long long *)&V0[8] = *(signed long long *)&b[ldb * 9]; + *(signed long long *)&V1[0] = *(signed long long *)&b[ldb * 10]; + *(signed long long *)&V1[8] = *(signed long long *)&b[ldb * 11]; + *(signed long long *)&V2[0] = *(signed long long *)&b[ldb * 12]; + *(signed long long *)&V2[8] = *(signed long long *)&b[ldb * 13]; + *(signed long long *)&V3[0] = *(signed long long *)&b[ldb * 14]; + *(signed long long *)&V3[8] = *(signed long long *)&b[ldb * 15]; + + D0 = vec_perm(V0, V1, swizA); + D1 = vec_perm(V2, V3, swizA); + D2 = vec_perm(V0, V1, swizB); + D3 = vec_perm(V2, V3, swizB); + + *(VecType *)&Bp[32] = D0; + *(VecType *)&Bp[48] = D1; + *(VecType *)&Bp[96] = D2; + *(VecType *)&Bp[112] = D3; + + vsum3 = vec_sum4s(D0, vsum3); + vsum3 = vec_sum4s(D2, vsum3); + vsum4 = vec_sum4s(D1, vsum4); + vsum4 = vec_sum4s(D3, vsum4); + + Bp += 16 * 8; + b += 8; + y -= 8; + } + while (y >= 4) { + int a1 = *reinterpret_cast(&b[0]); + int a2 = *reinterpret_cast(&b[ldb * 1]); + int a3 = *reinterpret_cast(&b[ldb * 2]); + int a4 = *reinterpret_cast(&b[ldb * 3]); + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *reinterpret_cast(&Bp[0]) = vec_row1; + vsum1 = vec_sum4s(vec_row1, vsum1); + + a1 = *reinterpret_cast(&b[ldb * 4]); + a2 = *reinterpret_cast(&b[ldb * 5]); + a3 = *reinterpret_cast(&b[ldb * 6]); + a4 = *reinterpret_cast(&b[ldb * 7]); + __vector int vec_a1 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a1); + *reinterpret_cast(&Bp[16]) = vec_row1; + vsum2 = vec_sum4s(vec_row1, vsum2); + + a1 = *reinterpret_cast(&b[ldb * 8]); + a2 = *reinterpret_cast(&b[ldb * 9]); + a3 = *reinterpret_cast(&b[ldb * 10]); + a4 = *reinterpret_cast(&b[ldb * 11]); + __vector int vec_a2 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a2); + *reinterpret_cast(&Bp[32]) = vec_row1; + vsum3 = vec_sum4s(vec_row1, vsum3); + + 
a1 = *reinterpret_cast(&b[ldb * 12]); + a2 = *reinterpret_cast(&b[ldb * 13]); + a3 = *reinterpret_cast(&b[ldb * 14]); + a4 = *reinterpret_cast(&b[ldb * 15]); + __vector int vec_a3 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a3); + *reinterpret_cast(&Bp[48]) = vec_row1; + vsum4 = vec_sum4s(vec_row1, vsum4); + + Bp += 64; + y -= 4; + b += 4; + } + if (y >= 1 && y <= 3) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail2 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail3 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail4 + = reinterpret_cast(vec_splats(uint8_t(0))); + + vec_tail1[0] = b[0]; + vec_tail1[4] = b[ldb]; + vec_tail1[8] = b[ldb * 2]; + vec_tail1[12] = b[ldb * 3]; + + vec_tail2[0] = b[ldb * 4]; + vec_tail2[4] = b[ldb * 5]; + vec_tail2[8] = b[ldb * 6]; + vec_tail2[12] = b[ldb * 7]; + + vec_tail3[0] = b[ldb * 8]; + vec_tail3[4] = b[ldb * 9]; + vec_tail3[8] = b[ldb * 10]; + vec_tail3[12] = b[ldb * 11]; + + vec_tail4[0] = b[ldb * 12]; + vec_tail4[4] = b[ldb * 13]; + vec_tail4[8] = b[ldb * 14]; + vec_tail4[12] = b[ldb * 15]; + + if (y == 3) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[13] = b[ldb * 7 + 1]; + + vec_tail1[2] = b[2]; + vec_tail1[6] = b[ldb + 2]; + vec_tail1[10] = b[ldb * 2 + 2]; + vec_tail1[14] = b[ldb * 3 + 2]; + + vec_tail2[2] = b[ldb * 4 + 2]; + vec_tail2[6] = b[ldb * 5 + 2]; + vec_tail2[10] = b[ldb * 6 + 2]; + vec_tail2[14] = b[ldb * 7 + 2]; + + vec_tail3[1] = b[ldb * 8 + 1]; + vec_tail3[5] = b[ldb * 9 + 1]; + vec_tail3[9] = b[ldb * 10 + 1]; + vec_tail3[13] = b[ldb * 11 + 1]; + + vec_tail3[2] = b[ldb * 8 + 2]; + vec_tail3[6] = b[ldb * 9 + 2]; + vec_tail3[10] = b[ldb * 10 + 2]; + vec_tail3[14] = b[ldb * 11 + 2]; + + vec_tail4[1] = b[ldb * 12 + 1]; + vec_tail4[5] = b[ldb * 13 + 1]; + 
vec_tail4[9] = b[ldb * 14 + 1]; + vec_tail4[13] = b[ldb * 15 + 1]; + + vec_tail4[2] = b[ldb * 12 + 2]; + vec_tail4[6] = b[ldb * 13 + 2]; + vec_tail4[10] = b[ldb * 14 + 2]; + vec_tail4[14] = b[ldb * 15 + 2]; + } + if (y == 2) { + vec_tail1[1] = b[1]; + vec_tail1[5] = b[ldb + 1]; + vec_tail1[9] = b[ldb * 2 + 1]; + vec_tail1[13] = b[ldb * 3 + 1]; + + vec_tail2[1] = b[ldb * 4 + 1]; + vec_tail2[5] = b[ldb * 5 + 1]; + vec_tail2[9] = b[ldb * 6 + 1]; + vec_tail2[13] = b[ldb * 7 + 1]; + + vec_tail3[1] = b[ldb * 8 + 1]; + vec_tail3[5] = b[ldb * 9 + 1]; + vec_tail3[9] = b[ldb * 10 + 1]; + vec_tail3[13] = b[ldb * 11 + 1]; + + vec_tail4[1] = b[ldb * 12 + 1]; + vec_tail4[5] = b[ldb * 13 + 1]; + vec_tail4[9] = b[ldb * 14 + 1]; + vec_tail4[13] = b[ldb * 15 + 1]; + } + *reinterpret_cast(&Bp[0]) = vec_tail1; + *reinterpret_cast(&Bp[16]) = vec_tail2; + *reinterpret_cast(&Bp[32]) = vec_tail3; + *reinterpret_cast(&Bp[48]) = vec_tail4; + + vsum1 = vec_sum4s(vec_tail1, vsum1); + vsum2 = vec_sum4s(vec_tail2, vsum2); + vsum3 = vec_sum4s(vec_tail3, vsum3); + vsum4 = vec_sum4s(vec_tail4, vsum4); + Bp += 64; + } + row_sum_eff[0] = vsum1[0]; + row_sum_eff[1] = vsum1[1]; + row_sum_eff[2] = vsum1[2]; + row_sum_eff[3] = vsum1[3]; + + row_sum_eff[4] = vsum2[0]; + row_sum_eff[5] = vsum2[1]; + row_sum_eff[6] = vsum2[2]; + row_sum_eff[7] = vsum2[3]; + + row_sum_eff[8] = vsum3[0]; + row_sum_eff[9] = vsum3[1]; + row_sum_eff[10] = vsum3[2]; + row_sum_eff[11] = vsum3[3]; + + row_sum_eff[12] = vsum4[0]; + row_sum_eff[13] = vsum4[1]; + row_sum_eff[14] = vsum4[2]; + row_sum_eff[15] = vsum4[3]; + row_sum_eff += 16; + N_dim -= 16; + B += 16 * ldb; + } + + if (N_dim > 12 && N_dim < 16) { + tailBlock16_12xK<__vector signed char>( + K_dim, N_dim, B, ldb, Bp, row_sum_eff); + } else if (N_dim >= 1 && N_dim <= 12) { + tailBlock16_8xK<__vector signed char>( + K_dim, N_dim, B, ldb, Bp, row_sum_eff); + } + return 0; +} + +template +void pack_N16_8bit_V2(int K_Dim, int N_Dim, const int8_t *A, int lda /* N */, + int8_t 
*Apacked, int32_t *row_sum_eff) { + + int K_block = (K_Dim + 3) & (~3); + int N_block = (N_Dim + 3) & (~3); + + typedef __vector unsigned char vec_t; + vec_t mask = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; + + while (N_Dim >= 16) { + const int8_t *a = A; + + size_t y = K_Dim; + while (y >= 16) { + //1st 4 Columns + VecType a1 = *reinterpret_cast(&a[0]); + VecType a2 = *reinterpret_cast(&a[lda]); + VecType a3 = *reinterpret_cast(&a[lda * 2]); + VecType a4 = *reinterpret_cast(&a[lda * 3]); + + VecType vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + VecType vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + VecType vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + VecType vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + + VecType vx_row1 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row5 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row9 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row13 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Apacked[0]) = vx_row1; + *reinterpret_cast(&Apacked[64]) = vx_row5; + *reinterpret_cast(&Apacked[128]) = vx_row9; + *reinterpret_cast(&Apacked[192]) = vx_row13; + + // 2nd 4 Columns + a1 = *reinterpret_cast(&a[lda * 4]); + a2 = *reinterpret_cast(&a[lda * 5]); + a3 = *reinterpret_cast(&a[lda * 6]); + a4 = *reinterpret_cast(&a[lda * 7]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); 
+ vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + + VecType vx_row2 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row6 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row10 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row14 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Apacked[16]) = vx_row2; + *reinterpret_cast(&Apacked[80]) = vx_row6; + *reinterpret_cast(&Apacked[144]) = vx_row10; + *reinterpret_cast(&Apacked[208]) = vx_row14; + + // 3rd 4 Columns + a1 = *reinterpret_cast(&a[lda * 8]); + a2 = *reinterpret_cast(&a[lda * 9]); + a3 = *reinterpret_cast(&a[lda * 10]); + a4 = *reinterpret_cast(&a[lda * 11]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + vec_even34 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + + VecType vx_row3 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row7 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row11 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row15 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Apacked[32]) = vx_row3; + *reinterpret_cast(&Apacked[96]) = vx_row7; + *reinterpret_cast(&Apacked[160]) = vx_row11; + *reinterpret_cast(&Apacked[224]) = vx_row15; + + // 4th 4 Columns + a1 = *reinterpret_cast(&a[lda * 12]); + a2 = *reinterpret_cast(&a[lda * 13]); + a3 = *reinterpret_cast(&a[lda * 14]); + a4 = *reinterpret_cast(&a[lda * 15]); + + vec_even12 = reinterpret_cast( + vec_mergee(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + vec_even34 = reinterpret_cast( + 
vec_mergee(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + vec_odd12 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2))); + vec_odd34 = reinterpret_cast( + vec_mergeo(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4))); + + VecType vx_row4 = vec_xxpermdi(vec_even12, vec_even34, 0); + VecType vx_row8 = vec_xxpermdi(vec_odd12, vec_odd34, 0); + VecType vx_row12 = vec_xxpermdi(vec_even12, vec_even34, 3); + VecType vx_row16 = vec_xxpermdi(vec_odd12, vec_odd34, 3); + + *reinterpret_cast(&Apacked[48]) = vx_row4; + *reinterpret_cast(&Apacked[112]) = vx_row8; + *reinterpret_cast(&Apacked[176]) = vx_row12; + *reinterpret_cast(&Apacked[240]) = vx_row16; + + y -= 16; + Apacked += 256; + a += 16; + } + + while (y >= 4) { + + int a1 = *reinterpret_cast(&a[0]); + int a2 = *reinterpret_cast(&a[lda * 1]); + int a3 = *reinterpret_cast(&a[lda * 2]); + int a4 = *reinterpret_cast(&a[lda * 3]); + __vector int vec_a = {a1, a2, a3, a4}; + VecType vec_row1 = reinterpret_cast(vec_a); + *reinterpret_cast(&Apacked[0]) = vec_row1; + + // Next 4 Column + a1 = *reinterpret_cast(&a[lda * 4]); + a2 = *reinterpret_cast(&a[lda * 5]); + a3 = *reinterpret_cast(&a[lda * 6]); + a4 = *reinterpret_cast(&a[lda * 7]); + __vector int vec_a1 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a1); + *reinterpret_cast(&Apacked[16]) = vec_row1; + + // Next 4 Column + a1 = *reinterpret_cast(&a[lda * 8]); + a2 = *reinterpret_cast(&a[lda * 9]); + a3 = *reinterpret_cast(&a[lda * 10]); + a4 = *reinterpret_cast(&a[lda * 11]); + __vector int vec_a2 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a2); + *reinterpret_cast(&Apacked[32]) = vec_row1; + + // Next 4 Column + a1 = *reinterpret_cast(&a[lda * 12]); + a2 = *reinterpret_cast(&a[lda * 13]); + a3 = *reinterpret_cast(&a[lda * 14]); + a4 = *reinterpret_cast(&a[lda * 15]); + __vector int vec_a3 = {a1, a2, a3, a4}; + vec_row1 = reinterpret_cast(vec_a3); 
+ *reinterpret_cast(&Apacked[48]) = vec_row1; + + Apacked += 64; + a += 4; + y -= 4; + } + + if (y <= 3 && y >= 1) { + VecType vec_tail1 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail2 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail3 + = reinterpret_cast(vec_splats(uint8_t(0))); + VecType vec_tail4 + = reinterpret_cast(vec_splats(uint8_t(0))); + + vec_tail1[0] = a[0]; + vec_tail1[4] = a[lda]; + vec_tail1[8] = a[lda * 2]; + vec_tail1[12] = a[lda * 3]; + + vec_tail2[0] = a[lda * 4]; + vec_tail2[4] = a[lda * 5]; + vec_tail2[8] = a[lda * 6]; + vec_tail2[12] = a[lda * 7]; + + vec_tail3[0] = a[lda * 8]; + vec_tail3[4] = a[lda * 9]; + vec_tail3[8] = a[lda * 10]; + vec_tail3[12] = a[lda * 11]; + + vec_tail4[0] = a[lda * 12]; + vec_tail4[4] = a[lda * 13]; + vec_tail4[8] = a[lda * 14]; + vec_tail4[12] = a[lda * 15]; + if (y == 3) { + vec_tail1[1] = a[1]; + vec_tail1[2] = a[2]; + vec_tail1[5] = a[lda + 1]; + vec_tail1[6] = a[lda + 2]; + + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[10] = a[lda * 2 + 2]; + vec_tail1[13] = a[lda * 3 + 1]; + vec_tail1[14] = a[lda * 3 + 2]; + + // Next 4 rows + vec_tail2[1] = a[lda * 4 + 1]; + vec_tail2[2] = a[lda * 4 + 2]; + vec_tail2[5] = a[lda * 5 + 1]; + vec_tail2[6] = a[lda * 5 + 2]; + vec_tail2[9] = a[lda * 6 + 1]; + vec_tail2[10] = a[lda * 6 + 2]; + vec_tail2[13] = a[lda * 7 + 1]; + vec_tail2[14] = a[lda * 7 + 2]; + + vec_tail3[1] = a[lda * 8 + 1]; + vec_tail3[2] = a[lda * 8 + 2]; + vec_tail3[5] = a[lda * 9 + 1]; + vec_tail3[6] = a[lda * 9 + 2]; + vec_tail3[9] = a[lda * 10 + 1]; + vec_tail3[10] = a[lda * 10 + 2]; + vec_tail3[13] = a[lda * 11 + 1]; + vec_tail3[14] = a[lda * 11 + 2]; + + vec_tail4[1] = a[lda * 12 + 1]; + vec_tail4[2] = a[lda * 12 + 2]; + vec_tail4[5] = a[lda * 13 + 1]; + vec_tail4[6] = a[lda * 13 + 2]; + vec_tail4[9] = a[lda * 14 + 1]; + vec_tail4[10] = a[lda * 14 + 2]; + vec_tail4[13] = a[lda * 15 + 1]; + vec_tail4[14] = a[lda * 15 + 2]; + } + if (y == 2) { + vec_tail1[1] = a[1]; + 
vec_tail1[5] = a[lda + 1]; + vec_tail1[9] = a[lda * 2 + 1]; + vec_tail1[13] = a[lda * 3 + 1]; + + vec_tail2[1] = a[lda * 4 + 1]; + vec_tail2[5] = a[lda * 5 + 1]; + vec_tail2[9] = a[lda * 6 + 1]; + vec_tail2[13] = a[lda * 7 + 1]; + + vec_tail3[1] = a[lda * 8 + 1]; + vec_tail3[5] = a[lda * 9 + 1]; + vec_tail3[9] = a[lda * 10 + 1]; + vec_tail3[13] = a[lda * 11 + 1]; + + vec_tail4[1] = a[lda * 12 + 1]; + vec_tail4[5] = a[lda * 13 + 1]; + vec_tail4[9] = a[lda * 14 + 1]; + vec_tail4[13] = a[lda * 15 + 1]; + } + *reinterpret_cast(&Apacked[0]) = vec_tail1; + *reinterpret_cast(&Apacked[16]) = vec_tail2; + *reinterpret_cast(&Apacked[32]) = vec_tail3; + *reinterpret_cast(&Apacked[48]) = vec_tail4; + Apacked += 64; + } + N_Dim -= 16; + A += lda * 16; + } + + if (N_Dim > 12 && N_Dim < 16) { + tailBlock16_12xK<__vector signed char>( + K_Dim, N_Dim, A, lda, Apacked, row_sum_eff); + } else if (N_Dim >= 1 && N_Dim <= 12) { + tailBlock16_8xK<__vector signed char>( + K_Dim, N_Dim, A, lda, Apacked, row_sum_eff); + } +} -int pack_N16_8bit(dim_t k, dim_t m, const int8_t *a, dim_t lda, int8_t *ap) { +inline int pack_N16_8bit( + dim_t k, dim_t m, const int8_t *a, dim_t lda, int8_t *ap) { int32_t i, j; int32_t m_cap = (m + 3) & ~3; int32_t k_cap = (k + 3) & ~3; @@ -1135,7 +5454,152 @@ int pack_N16_8bit(dim_t k, dim_t m, const int8_t *a, dim_t lda, int8_t *ap) { return 0; } -int pack_T8_8bit(dim_t k, dim_t n, const uint8_t *b, dim_t ldb, uint8_t *bp) { +template +inline int pack_T8_8bit_V2_signed(dim_t k, dim_t n, const b_type *b, dim_t ldb, + uint8_t *bp, bool is_signed) { + int32_t i, j; + int32_t kcell, cell, koff, noff, krows, k8, n8; + int32_t n_cap = (n + 3) & ~3; + int32_t k_cap = (k + 3) & ~3; + krows = (k + 3) >> 2; + k8 = (k >> 3) << 3; + n8 = (n >> 3) << 3; + + const uint8_t BitFlipValue = (is_signed ? 
0x80 : 0); + + VecType vmask = reinterpret_cast(vec_splats(BitFlipValue)); + + // MAIN BLOCK + for (i = 0; i < k8; i += 8) { + for (j = 0; j < n8; j += 8) { + VecType V0, V1, V2, V3; + VecType D01A, D01B, D23A, D23B; + VecType D0, D1, D2, D3; + vec_t swizA + = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23}; + vec_t swizB = {8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, + 15, 31}; + vec_t swizL + = {0, 1, 16, 17, 2, 3, 18, 19, 4, 5, 20, 21, 6, 7, 22, 23}; + vec_t swizR = {8, 9, 24, 25, 10, 11, 26, 27, 12, 13, 28, 29, 14, 15, + 30, 31}; + uint8_t *dest; + + *(signed long long *)&V0[0] + = *(signed long long *)&b[ldb * (i + 0) + j]; + *(signed long long *)&V1[0] + = *(signed long long *)&b[ldb * (i + 1) + j]; + *(signed long long *)&V2[0] + = *(signed long long *)&b[ldb * (i + 2) + j]; + *(signed long long *)&V3[0] + = *(signed long long *)&b[ldb * (i + 3) + j]; + *(signed long long *)&V0[8] + = *(signed long long *)&b[ldb * (i + 4) + j]; + *(signed long long *)&V1[8] + = *(signed long long *)&b[ldb * (i + 5) + j]; + *(signed long long *)&V2[8] + = *(signed long long *)&b[ldb * (i + 6) + j]; + *(signed long long *)&V3[8] + = *(signed long long *)&b[ldb * (i + 7) + j]; + + D01A = vec_perm(V0, V1, swizA); + D01B = vec_perm(V0, V1, swizB); + D23A = vec_perm(V2, V3, swizA); + D23B = vec_perm(V2, V3, swizB); + D0 = vec_perm(D01A, D23A, swizL); + D1 = vec_perm(D01A, D23A, swizR); + D2 = vec_perm(D01B, D23B, swizL); + D3 = vec_perm(D01B, D23B, swizR); + + dest = &bp[16 * ((j >> 2) * krows + (i >> 1))]; + + *(vec_t *)&dest[0] + = reinterpret_cast(vec_add(D0, vmask)); //D0; + *(vec_t *)&dest[16] + = reinterpret_cast(vec_add(D1, vmask)); // D1; + *(vec_t *)&dest[32] + = reinterpret_cast(vec_add(D2, vmask)); //D2; + *(vec_t *)&dest[48] + = reinterpret_cast(vec_add(D3, vmask)); //D3; + } + } + + // HIGH EDGE IN N DIRECTION + for (i = 0; i < k8; ++i) { + for (j = n8; j < n_cap; ++j) { + kcell = i >> 2; + // special handling if j is in a PARTIAL last "group of 8" 
+ int32_t maingroup = (j & (~7)) < (n & (~7)); + int32_t columns_done = ((j & (~7)) >> 3) << 1; + int32_t groupwidth = (maingroup || ((n & 7) > 4)) ? 2 : 1; + int32_t j_hiflag = (j & 4) >> 2; + cell = columns_done * krows + kcell * groupwidth + j_hiflag; + koff = i & 3; + noff = j & 3; + if (j < n) + bp[16 * cell + 4 * noff + koff] = b[ldb * i + j] + BitFlipValue; + else + bp[16 * cell + 4 * noff + koff] = 0; + } + } + + // HIGH EDGE IN K DIRECTION + for (i = k8; i < k_cap; ++i) { + for (j = 0; j < n8; ++j) { + kcell = i >> 2; + // special handling if j is in a PARTIAL last "group of 8" + int32_t maingroup = (j & (~7)) < (n & (~7)); + int32_t columns_done = ((j & (~7)) >> 3) << 1; + int32_t groupwidth = (maingroup || ((n & 7) > 4)) ? 2 : 1; + int32_t j_hiflag = (j & 4) >> 2; + cell = columns_done * krows + kcell * groupwidth + j_hiflag; + koff = i & 3; + noff = j & 3; + if (i < k) + bp[16 * cell + 4 * noff + koff] = b[ldb * i + j] + BitFlipValue; + else + bp[16 * cell + 4 * noff + koff] = 0; + } + } + + // UPPER CORNER (HIGH N, HIGH K) + for (i = k8; i < k_cap; ++i) { + for (j = n8; j < n_cap; ++j) { + kcell = i >> 2; + // special handling if j is in a PARTIAL last "group of 8" + int32_t maingroup = (j & (~7)) < (n & (~7)); + int32_t columns_done = ((j & (~7)) >> 3) << 1; + int32_t groupwidth = (maingroup || ((n & 7) > 4)) ? 
2 : 1; + int32_t j_hiflag = (j & 4) >> 2; + cell = columns_done * krows + kcell * groupwidth + j_hiflag; + koff = i & 3; + noff = j & 3; + if (i < k && j < n) + bp[16 * cell + 4 * noff + koff] = b[ldb * i + j] + BitFlipValue; + else + bp[16 * cell + 4 * noff + koff] = 0; + } + } + + return 0; +} + +template +inline int packB_T8_8bit(dim_t k, dim_t n, const b_type *b, dim_t ldb, + uint8_t *bp, bool is_signed) { + + if (is_signed) { + pack_T8_8bit_V2_signed<__vector signed char, b_type>( + k, n, b, ldb, bp, true); + } else { + pack_T8_8bit_V2_signed<__vector unsigned char, b_type>( + k, n, b, ldb, bp, false); + } + return 0; +} + +inline int pack_T8_8bit( + dim_t k, dim_t n, const uint8_t *b, dim_t ldb, uint8_t *bp) { int32_t i, j; int32_t kcell, cell, koff, noff, krows, k8, n8; int32_t n_cap = (n + 3) & ~3; @@ -1406,7 +5870,7 @@ typedef __vector int32_t v4si_t __attribute__((aligned(4))); #define MMA __builtin_mma_xvi16ger2pp -void gemm_kernel_16bit(dim_t m, dim_t n, dim_t k, float alpha, short *A, +inline void gemm_kernel_16bit(dim_t m, dim_t n, dim_t k, float alpha, short *A, short *B, int32_t *C, float beta, dim_t ldc) { int32_t i; int32_t m_cap = (m + 3) & ~3; @@ -2240,8 +6704,7 @@ void gemm_kernel_16bit(dim_t m, dim_t n, dim_t k, float alpha, short *A, #undef MMA #define MMA __builtin_mma_xvi8ger4pp - -void gemm_kernel_8bit(dim_t m, dim_t n, dim_t k, float alpha, int8_t *A, +inline void gemm_kernel_8bit(dim_t m, dim_t n, dim_t k, float alpha, int8_t *A, uint8_t *B, int32_t *C, float beta, dim_t ldc) { int32_t i; int32_t m_cap = (m + 3) & ~3; diff --git a/src/cpu/reorder/cpu_reorder.hpp b/src/cpu/reorder/cpu_reorder.hpp index e8a1795b76b..b8517344709 100644 --- a/src/cpu/reorder/cpu_reorder.hpp +++ b/src/cpu/reorder/cpu_reorder.hpp @@ -41,6 +41,8 @@ #if defined(DNNL_AARCH64_USE_ACL) #include "cpu/aarch64/acl_reorder.hpp" #endif +#elif DNNL_PPC64 +#include "cpu/ppc64/ppc64_gemm_reorder.hpp" #endif #include "cpu/rnn/rnn_reorders.hpp" diff --git 
a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp index 12ed3556dae..99da33a193c 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp @@ -36,6 +36,8 @@ const impl_list_map_t ®ular_f32_u8_impl_list_map() { DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_PPC64_ONLY(CPU_REORDER_INSTANCE(ppc64::ppc64_matrixA_reorder_t)) + REG_FAST_DIRECT_COPY(f32, u8) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, u8, nChw16c))