From 154acd40fa2e1cca16c98875f18015103db0ccb2 Mon Sep 17 00:00:00 2001 From: "Xuxin, Zeng" Date: Tue, 25 Jun 2024 14:28:48 -0700 Subject: [PATCH] cpu: x64: matmul: enable int4 weights decompression --- src/cpu/x64/matmul/brgemm_matmul.cpp | 8 +- .../x64/matmul/brgemm_matmul_copy_utils.cpp | 99 +++++++++++++++++-- src/cpu/x64/matmul/brgemm_matmul_utils.cpp | 20 +++- src/cpu/x64/matmul/brgemm_matmul_utils.hpp | 8 +- 4 files changed, 117 insertions(+), 18 deletions(-) diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp index 441b2abe199..eb9d35a16fc 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul.cpp @@ -57,8 +57,8 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { = everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32); const bool is_f16 = everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32); - const bool is_bf16_with_int_wei = src_dt == bf16 && one_of(wei_dt, s8, u8) - && one_of(dst_dt, bf16, f32); + const bool is_bf16_with_int_wei = src_dt == bf16 + && one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, bf16, f32); auto check_bias = [&]() -> bool { const auto bia_dt = weights_md(1)->data_type; @@ -1126,7 +1126,9 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { } int cur_b = get_bb_idx(b, bgmmc_.bcast_B_desc); - return data_B_ptr_ + get_data_B_off(cur_b, k, n); + const dim_t B_off = get_data_B_off(cur_b, k, n); + assert(IMPLICATION(bgmmc_.is_int4_weights, B_off % 2 == 0)); + return data_B_ptr_ + (bgmmc_.is_int4_weights ? B_off / 2 : B_off); } const char *get_data_B_bitmask_ptr(int b, int k, int n) const { diff --git a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp index d77de3a25e3..b877d6369a8 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp @@ -2831,6 +2831,7 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, , src_stride(conf->copy_B_wei_stride) , tr_src_stride(conf_->LDB * k_blk_step * tr_typesize) , scales_N_stride(conf_->N * scales_typesize) + , is_src_int4(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4)) , is_dynamic_stride(is_runtime_value(src_stride)) , is_dynamic_N(conf->is_runtime_N) , req_cvtps2bf16(conf->is_bf32 || conf->is_bf16_with_int_wei) @@ -2846,10 +2847,12 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, using opmask_t = const Xbyak::Opmask; using zmm = const Xbyak::Zmm; using ymm = const Xbyak::Ymm; + using Vmm_lower_t = typename vreg_traits::Vmm_lower_t; enum { k_blk_step = 2, n_blk_step = 16 }; const int typesize, tr_typesize, scales_typesize; const dim_t src_stride, tr_src_stride, scales_N_stride; + const bool is_src_int4; const bool is_dynamic_stride; const bool is_dynamic_N; const bool req_cvtps2bf16; @@ -2863,6 +2866,9 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, opmask_t kTail = k7; opmask_t kFFFF = k6; + opmask_t kTail_int4 = k5; + opmask_t kAAAA = k4; + opmask_t kSign = k3; reg64_t reg_src = rax; reg64_t reg_tr_src = rbx; @@ -2886,6 +2892,10 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, Vmm vmm_permw = Vmm(1); Vmm vmm_tmp = Vmm(1); // used only for avx2_vnni_2 Vmm vmm_zp_b_shift = Vmm(2); + Vmm vmm_permd = Vmm(3); + Vmm vmm_int4_mask = Vmm(4); + Vmm vmm_sign_bit = Vmm(5); + Vmm vmm_sign_mask = Vmm(6); void kmovx(Opmask k, unsigned w) { if (!isa_has_masks(conf_->isa)) return; @@ -2901,6 +2911,21 @@ struct 
jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, else jit_generator::kmovd(k, regw_tmp); } + void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) { + vinserti64x4(zmm, zmm, ymm_half, 1); + } + void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) { + vinserti128(ymm, ymm, xmm_half, 1); + } + Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) { + assert(is_src_int4); + if (isa_has_masks(conf_->isa)) { + return is_tail ? vmm_lower | kTail_int4 | T_z + : vmm_lower | kFFFF | T_z; + } else { + return vmm_lower; + } + } Vmm maybe_mask(Vmm vmm, bool is_tail) { if (isa_has_masks(conf_->isa)) { return is_tail ? vmm | kTail | T_z : vmm | kFFFF | T_z; @@ -2908,12 +2933,47 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, return vmm; } } + void load_int( + const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail = false); void copy_block(int nrows, int ncolumns, bool n_tail); void copy_2x32(int nrows, int ncolumns); void init_masks(); void generate() override; }; +template +void jit_brgemm_matmul_copy_b_bf16_t::load_int( + const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail) { + const auto vmm = maybe_mask(vmm_in, is_tail); + const auto vmm_lower = Vmm_lower_t(vmm.getIdx()); + const auto is_s4 = conf_->orig_wei_dt == data_type::s4; + MAYBE_UNUSED(vmm_lower); + MAYBE_UNUSED(is_s4); + + switch (conf_->orig_wei_dt) { + case data_type::s8: uni_vpmovsxbd(vmm, op); break; + case data_type::u8: uni_vpmovzxbd(vmm, op); break; + // For int4, we see two int4 as one int8 and extend them int32 + // low half stores in lower bytes of vmm and high half in higher + // bytes of vmm, then permute them into correct order + // Finally, we process the extend bytes for s4/u4 accordingly + case data_type::s4: + case data_type::u4: + if (is_s4) + uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op); + else + uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op); + copy_half_int4(vmm_in, vmm_lower); + vpermd(vmm_in, vmm_permd, vmm_in); + vpsrld(vmm_in | kAAAA, vmm_in, 4); + if (is_s4) vptestmd(kSign, vmm_in, vmm_sign_bit); + vpandd(vmm_in, vmm_in, vmm_int4_mask); + if (is_s4) vpord(vmm_in | kSign, vmm_in, vmm_sign_mask); + break; + default: assert(!"unsupported data type"); + } +} + template void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32(int nrows, int ncolumns) { @@ -2921,10 +2981,16 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32(int nrows, int ncolumns) { if (columns_tail > 0 && columns_tail < n_blk_step) { const auto tail_mask = (1 << columns_tail) - 1; kmovx(kTail, tail_mask); + if (is_src_int4) { + const auto int4_tail_mask = (1 << (columns_tail / 2)) - 1; + kmovx(kTail_int4, int4_tail_mask); + } } static constexpr int blk_sz = k_blk_step; - const int reserved_regs = req_zp_b_shift ? 3 : 2; + const int reserved_regs = !is_src_int4 + ? (req_zp_b_shift ? 3 : 2) + : (conf_->orig_wei_dt == data_type::s4 ? 7 : 5); const int max_isa_regs = isa_num_vregs(conf_->isa); const int max_regs_available = max_isa_regs - reserved_regs; const int max_unroll = max_regs_available / blk_sz; @@ -2941,8 +3007,9 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32(int nrows, int ncolumns) { auto src_reg = get_vmm(blk, k % k_blk_step); const bool is_tail = ncolumns - n < n_blk_step; auto src_load = maybe_mask(src_reg, is_tail); - const auto offset - = (is_dynamic_stride ? 0 : k * src_stride) + n * typesize; + const auto factor = is_src_int4 ? 2 : 1; + const auto offset = (is_dynamic_stride ? 
0 : k * src_stride) + + ((n * typesize) / factor); const auto reg_src_load = is_dynamic_stride && k % 2 != 0 ? reg_src_load_1 : reg_src; auto load_addr = maybe_EVEX_compress_addr(reg_src_load, offset); @@ -2955,10 +3022,7 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32(int nrows, int ncolumns) { if (conf_->is_bf32) uni_vmovups(src_load, load_addr); else if (conf_->is_bf16_with_int_wei) { - if (conf_->orig_wei_dt == data_type::s8) - uni_vpmovsxbd(src_load, load_addr); - else - uni_vpmovzxbd(src_load, load_addr); + load_int(src_reg, load_addr, is_tail); if (req_zp_b_shift) uni_vpsubd(src_load, src_load, vmm_zp_b_shift); uni_vcvtdq2ps(src_load, src_load); @@ -3054,6 +3118,27 @@ void jit_brgemm_matmul_copy_b_bf16_t::init_masks() { mov(reg_tmp, reinterpret_cast(bf16_vnni_permute)); vmovdqa64(vmm_permw, ptr[reg_tmp]); + + if (is_src_int4) { + alignas(64) static constexpr const uint32_t int4_permute[16] + = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}; + mov(reg_tmp, reinterpret_cast(int4_permute)); + vmovdqa32(vmm_permd, ptr[reg_tmp]); + + kmovx(kAAAA, 0xaaaa); + + const auto reg32_scratch = reg_tmp.cvt32(); + mov(reg32_scratch, 0xf); + vpbroadcastd(vmm_int4_mask, reg32_scratch); + + if (conf_->orig_wei_dt == data_type::s4) { + mov(reg32_scratch, 0x8); + vpbroadcastd(vmm_sign_bit, reg32_scratch); + + mov(reg32_scratch, 0xfffffff8); + vpbroadcastd(vmm_sign_mask, reg32_scratch); + } + } } } diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index ed4ac58c151..c9fe32bfb38 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -193,12 +193,11 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t( , bf32_dt(f32_dt && one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::any) && isa == avx512_core_amx) - , bf16_with_int_wei_dt(bgmmc.src_dt == bf16 - && utils::one_of(bgmmc.wei_dt, u8, s8) - && one_of(bgmmc.dst_dt, bf16, f32)) - , weights_decompression_support(one_of(bgmmc.wei_dt, u8, s8) + , weights_decompression_support(one_of(bgmmc.wei_dt, u8, s8, u4, s4) && one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::any) && attr.fpmath_.apply_to_int_) + , bf16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == bf16 + && one_of(bgmmc.dst_dt, bf16, f32)) , A_any_layout(A_any_layout) , B_any_layout(B_any_layout) , C_any_layout(C_any_layout) @@ -1190,6 +1189,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, bgmmc.is_bf32 = bm_conf_utils.is_bf32(); bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei(); bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression(); + bgmmc.is_int4_weights = one_of(bgmmc.wei_dt, data_type::s4, data_type::u4); // Make BRGeMM compute MatMul as if it were in bfloat16, while down-convert // happens during copy-buffer computations @@ -1325,6 +1325,12 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, && bgmmc.is_oscale_per_k && bgmmc.is_oscale_per_n && bgmmc.transposed_B; + // int4 weights decompression only supports plain layout for now + // TODO: enable int4 reorder and extend support to other weight layouts + if (bgmmc.with_wei_decompression && bgmmc.is_int4_weights) + VCONDCHECK_BG(bm_conf_utils.check_is_plain(bgmmc.wei_tag), + VERBOSE_UNSUPPORTED_TAG); + const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag); // if M == 1 we can still treat formally transposed A as plain // and avoid copy routine creation/execution @@ -1679,10 
+1685,14 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc, && wei_d.matches_one_of_tag(abcd) == format_tag::undef) { bgmmc.copy_B_wei_stride = bgmmc.K * bgmmc.b_dt_sz; } else { + const dim_t factor = bgmmc.is_int4_weights ? 2 : 1; const auto b_stride_elems = bgmmc.req_wei_vnni_downconvert ? bgmmc.LDB : bgmmc.N; + assert(IMPLICATION(bgmmc.is_int4_weights, b_stride_elems % 2 == 0)); bgmmc.copy_B_wei_stride - = bgmmc.is_runtime_N ? bgmmc.N : b_stride_elems * bgmmc.b_dt_sz; + = (bgmmc.is_runtime_N ? bgmmc.N + : b_stride_elems * bgmmc.b_dt_sz) + / factor; } bgmmc.C_ptr_shift_b = dst_d.matches_one_of_tag(acbd) diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp index d9868e7fc72..ad6552e67b8 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp @@ -188,6 +188,7 @@ struct brgemm_matmul_conf_t { int required_k_granularity; bool is_bf32 = false; bool is_bf16_with_int_wei = false; + bool is_int4_weights = false; bool req_wei_vnni_downconvert = false; bool is_runtime_M = false; bool is_runtime_N = false; @@ -288,7 +289,8 @@ struct brgemm_matmul_conf_utils_t { inline bool is_bf16_with_int_wei() const { return bf16_with_int_wei_dt; } inline bool with_weights_decompression() const { - return !utils::one_of(bgmmc.src_dt, data_type::s8, data_type::u8) + return !utils::one_of(bgmmc.src_dt, data_type::s8, data_type::u8, + data_type::s4, data_type::u4) && weights_decompression_support; } @@ -316,8 +318,8 @@ struct brgemm_matmul_conf_utils_t { private: brgemm_matmul_conf_t &bgmmc; - const bool f32_dt, bf16_dt, f16_dt, int8_dt, bf32_dt, bf16_with_int_wei_dt; - const bool weights_decompression_support; + const bool f32_dt, bf16_dt, f16_dt, int8_dt, bf32_dt; + const bool weights_decompression_support, bf16_with_int_wei_dt; const bool A_any_layout; const bool B_any_layout; const bool C_any_layout;
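
For reference, below is a minimal scalar sketch (not part of the patch) of the int4 weight decompression that the new `load_int` path in `jit_brgemm_matmul_copy_b_bf16_t` performs lane-wise. The packing assumption is inferred from the permute/shift/mask sequence (`vpermd` + `vpsrld` under `kAAAA` + `vpandd`, plus `kSign`/`vpord` for s4): two int4 elements per byte, with element `2*i` in the low nibble and element `2*i + 1` in the high nibble of byte `i`. The helper name, the `zp` parameter, and the f32 return type are hypothetical and only stand in for the zero-point shift (`vmm_zp_b_shift`) and the `uni_vcvtdq2ps` conversion that precede the bf16 down-convert.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference of the int4 path in load_int():
// unpack two nibbles per byte, optionally sign-extend (s4), apply the
// optional zero-point shift, and convert to f32.
std::vector<float> decompress_int4_row(
        const uint8_t *packed, size_t n_elems, bool is_s4, int32_t zp) {
    std::vector<float> out(n_elems);
    for (size_t e = 0; e < n_elems; ++e) {
        const uint8_t byte = packed[e / 2]; // two int4 elements per byte
        // Even elements come from the low nibble, odd elements from the
        // high nibble (mirrors vpermd + vpsrld(kAAAA) + vpandd(0xf)).
        int32_t v = (e % 2 == 0) ? (byte & 0xf) : ((byte >> 4) & 0xf);
        // s4: sign-extend the 4-bit value; equivalent to the kernel's
        // vptestmd(kSign, 0x8) followed by vpord with 0xfffffff8.
        if (is_s4 && (v & 0x8)) v -= 16;
        // Zero-point shift and integer -> f32 conversion
        // (uni_vpsubd + uni_vcvtdq2ps in the kernel).
        out[e] = static_cast<float>(v - zp);
    }
    return out;
}
```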
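
The addressing changes follow from the same packing: since two int4 elements share one byte, element offsets and strides into the weight tensor are halved when converted to byte offsets, and the patch asserts they are even first (see `get_data_B_ptr` and the `copy_B_wei_stride` computation). A tiny sketch of that bookkeeping, with a hypothetical helper name:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration of the byte-offset bookkeeping added in this
// patch: for int4 weights an element offset maps to elem_off / 2 bytes,
// and the offset is asserted to be even before dividing.
inline int64_t weights_byte_offset(int64_t elem_off, bool is_int4_weights) {
    assert(!is_int4_weights || elem_off % 2 == 0);
    return is_int4_weights ? elem_off / 2 : elem_off;
}
```

This is why the plain-layout restriction matters for now: with other weight layouts the copy routines could produce odd element offsets, which the halved byte addressing above does not handle until int4 reorders are enabled.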