cpu: x64: matmul: enable int4 weights decompression
xuxinzen committed Jun 26, 2024
1 parent 75b3946 commit 154acd4
Showing 4 changed files with 117 additions and 18 deletions.
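For context, the sketch below shows how a user-level matmul would reach this weights-decompression path: bf16 activations, plain-layout int4 (u4/s4) weights, and decompression requested through the fpmath attribute. This is a hedged illustration only; the shapes, engine setup, and the exact oneDNN C++ API spellings used here (memory::data_type::u4, set_fpmath_mode with apply_to_int) are assumptions and are not part of this commit.

#include "dnnl.hpp"

// Hypothetical usage sketch (not from this commit): bf16 src/dst with
// plain-layout u4 weights; decompression is requested via fpmath_mode.
void int4_weights_matmul_sketch(dnnl::engine &eng) {
    using namespace dnnl;
    const memory::dim M = 32, K = 64, N = 64; // illustrative sizes
    memory::desc src_md({M, K}, memory::data_type::bf16, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::u4, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    primitive_attr attr;
    // Compute in bf16 and let the mode apply to integer weights, which is
    // what requests weight decompression on this path.
    attr.set_fpmath_mode(fpmath_mode::bf16, /*apply_to_int=*/true);

    auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);
    auto prim = matmul(pd);
    // prim.execute(stream, {{DNNL_ARG_SRC, ...}, {DNNL_ARG_WEIGHTS, ...},
    //                       {DNNL_ARG_DST, ...}});
}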
8 changes: 5 additions & 3 deletions src/cpu/x64/matmul/brgemm_matmul.cpp
@@ -57,8 +57,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
= everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32);
const bool is_f16
= everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32);
const bool is_bf16_with_int_wei = src_dt == bf16 && one_of(wei_dt, s8, u8)
&& one_of(dst_dt, bf16, f32);
const bool is_bf16_with_int_wei = src_dt == bf16
&& one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, bf16, f32);

auto check_bias = [&]() -> bool {
const auto bia_dt = weights_md(1)->data_type;
@@ -1126,7 +1126,9 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
}

int cur_b = get_bb_idx(b, bgmmc_.bcast_B_desc);
return data_B_ptr_ + get_data_B_off(cur_b, k, n);
const dim_t B_off = get_data_B_off(cur_b, k, n);
assert(IMPLICATION(bgmmc_.is_int4_weights, B_off % 2 == 0));
return data_B_ptr_ + (bgmmc_.is_int4_weights ? B_off / 2 : B_off);
}

const char *get_data_B_bitmask_ptr(int b, int k, int n) const {
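Because two int4 elements share one byte, the exec context above halves the element offset returned by get_data_B_off() before adding it to data_B_ptr_, and asserts the offset is even so no element straddles a byte boundary. A minimal scalar sketch of that address computation follows; the helper name and signature are illustrative, not the actual member function.

#include <cassert>

// Illustrative only (not the actual member function): for int4 weights the
// offset from the indexing helper counts nibble-sized elements, so it is
// halved to get a byte offset; otherwise it is already a byte offset.
const char *data_B_addr(const char *base, long off, bool is_int4_weights) {
    if (!is_int4_weights) return base + off;
    assert(off % 2 == 0); // two int4 per byte; an element must not straddle bytes
    return base + off / 2;
}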
99 changes: 92 additions & 7 deletions src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
@@ -2831,6 +2831,7 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t,
, src_stride(conf->copy_B_wei_stride)
, tr_src_stride(conf_->LDB * k_blk_step * tr_typesize)
, scales_N_stride(conf_->N * scales_typesize)
, is_src_int4(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4))
, is_dynamic_stride(is_runtime_value(src_stride))
, is_dynamic_N(conf->is_runtime_N)
, req_cvtps2bf16(conf->is_bf32 || conf->is_bf16_with_int_wei)
@@ -2846,10 +2847,12 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t,
using opmask_t = const Xbyak::Opmask;
using zmm = const Xbyak::Zmm;
using ymm = const Xbyak::Ymm;
using Vmm_lower_t = typename vreg_traits<Vmm>::Vmm_lower_t;

enum { k_blk_step = 2, n_blk_step = 16 };
const int typesize, tr_typesize, scales_typesize;
const dim_t src_stride, tr_src_stride, scales_N_stride;
const bool is_src_int4;
const bool is_dynamic_stride;
const bool is_dynamic_N;
const bool req_cvtps2bf16;
@@ -2863,6 +2866,9 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t,

opmask_t kTail = k7;
opmask_t kFFFF = k6;
opmask_t kTail_int4 = k5;
opmask_t kAAAA = k4;
opmask_t kSign = k3;

reg64_t reg_src = rax;
reg64_t reg_tr_src = rbx;
@@ -2886,6 +2892,10 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t,
Vmm vmm_permw = Vmm(1);
Vmm vmm_tmp = Vmm(1); // used only for avx2_vnni_2
Vmm vmm_zp_b_shift = Vmm(2);
Vmm vmm_permd = Vmm(3);
Vmm vmm_int4_mask = Vmm(4);
Vmm vmm_sign_bit = Vmm(5);
Vmm vmm_sign_mask = Vmm(6);

void kmovx(Opmask k, unsigned w) {
if (!isa_has_masks(conf_->isa)) return;
@@ -2901,30 +2911,86 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t,
else
jit_generator::kmovd(k, regw_tmp);
}
void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
vinserti64x4(zmm, zmm, ymm_half, 1);
}
void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
vinserti128(ymm, ymm, xmm_half, 1);
}
Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) {
assert(is_src_int4);
if (isa_has_masks(conf_->isa)) {
return is_tail ? vmm_lower | kTail_int4 | T_z
: vmm_lower | kFFFF | T_z;
} else {
return vmm_lower;
}
}
Vmm maybe_mask(Vmm vmm, bool is_tail) {
if (isa_has_masks(conf_->isa)) {
return is_tail ? vmm | kTail | T_z : vmm | kFFFF | T_z;
} else {
return vmm;
}
}
void load_int(
const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail = false);
void copy_block(int nrows, int ncolumns, bool n_tail);
void copy_2x32(int nrows, int ncolumns);
void init_masks();
void generate() override;
};

template <typename Vmm>
void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::load_int(
const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail) {
const auto vmm = maybe_mask(vmm_in, is_tail);
const auto vmm_lower = Vmm_lower_t(vmm.getIdx());
const auto is_s4 = conf_->orig_wei_dt == data_type::s4;
MAYBE_UNUSED(vmm_lower);
MAYBE_UNUSED(is_s4);

switch (conf_->orig_wei_dt) {
case data_type::s8: uni_vpmovsxbd(vmm, op); break;
case data_type::u8: uni_vpmovzxbd(vmm, op); break;
// For int4, we load two int4 values as one int8 and extend them to int32.
// The low half is stored in the lower bytes of vmm and the high half in the
// higher bytes of vmm; they are then permuted into the correct order.
// Finally, the extended bytes are post-processed for s4/u4 accordingly.
case data_type::s4:
case data_type::u4:
if (is_s4)
uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op);
else
uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
copy_half_int4(vmm_in, vmm_lower);
vpermd(vmm_in, vmm_permd, vmm_in);
vpsrld(vmm_in | kAAAA, vmm_in, 4);
if (is_s4) vptestmd(kSign, vmm_in, vmm_sign_bit);
vpandd(vmm_in, vmm_in, vmm_int4_mask);
if (is_s4) vpord(vmm_in | kSign, vmm_in, vmm_sign_mask);
break;
default: assert(!"unsupported data type");
}
}
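load_int handles int4 by loading the packed nibbles as bytes, widening them to int32 in the lower half of the register, copying that half into the upper half, and using vpermd so each byte's two nibbles occupy adjacent dwords; the odd dwords (kAAAA = 0xaaaa) are shifted right by 4 to expose the high nibble, every lane is masked to its low 4 bits, and for s4 the lanes whose bit 3 was set are OR-ed with 0xfffffff8 to restore the sign. A scalar reference of the resulting element order and value decoding, as a hedged sketch rather than code from this commit:

#include <cstdint>
#include <vector>

// Hedged scalar equivalent of the int4 widening above (illustrative, not the
// JIT code): element 2*i is the low nibble of packed byte i, element 2*i + 1
// the high nibble; s4 nibbles with bit 3 set are sign-extended to int32.
std::vector<int32_t> unpack_int4(const uint8_t *packed, int n, bool is_s4) {
    std::vector<int32_t> out(n);
    for (int i = 0; i < n; ++i) {
        uint32_t nib = (i % 2 == 0) ? (packed[i / 2] & 0xfu)
                                    : ((packed[i / 2] >> 4) & 0xfu);
        if (is_s4 && (nib & 0x8u)) nib |= 0xfffffff8u; // restore the sign bits
        out[i] = static_cast<int32_t>(nib);
    }
    return out;
}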

template <typename Vmm>
void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(int nrows, int ncolumns) {

const int columns_tail = ncolumns % n_blk_step;
if (columns_tail > 0 && columns_tail < n_blk_step) {
const auto tail_mask = (1 << columns_tail) - 1;
kmovx(kTail, tail_mask);
if (is_src_int4) {
const auto int4_tail_mask = (1 << (columns_tail / 2)) - 1;
kmovx(kTail_int4, int4_tail_mask);
}
}

static constexpr int blk_sz = k_blk_step;
const int reserved_regs = req_zp_b_shift ? 3 : 2;
const int reserved_regs = !is_src_int4
? (req_zp_b_shift ? 3 : 2)
: (conf_->orig_wei_dt == data_type::s4 ? 7 : 5);
const int max_isa_regs = isa_num_vregs(conf_->isa);
const int max_regs_available = max_isa_regs - reserved_regs;
const int max_unroll = max_regs_available / blk_sz;
@@ -2941,8 +3007,9 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(int nrows, int ncolumns) {
auto src_reg = get_vmm(blk, k % k_blk_step);
const bool is_tail = ncolumns - n < n_blk_step;
auto src_load = maybe_mask(src_reg, is_tail);
const auto offset
= (is_dynamic_stride ? 0 : k * src_stride) + n * typesize;
const auto factor = is_src_int4 ? 2 : 1;
const auto offset = (is_dynamic_stride ? 0 : k * src_stride)
+ ((n * typesize) / factor);
const auto reg_src_load
= is_dynamic_stride && k % 2 != 0 ? reg_src_load_1 : reg_src;
auto load_addr = maybe_EVEX_compress_addr(reg_src_load, offset);
@@ -2955,10 +3022,7 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(int nrows, int ncolumns) {
if (conf_->is_bf32)
uni_vmovups(src_load, load_addr);
else if (conf_->is_bf16_with_int_wei) {
if (conf_->orig_wei_dt == data_type::s8)
uni_vpmovsxbd(src_load, load_addr);
else
uni_vpmovzxbd(src_load, load_addr);
load_int(src_reg, load_addr, is_tail);
if (req_zp_b_shift)
uni_vpsubd(src_load, src_load, vmm_zp_b_shift);
uni_vcvtdq2ps(src_load, src_load);
@@ -3054,6 +3118,27 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::init_masks() {

mov(reg_tmp, reinterpret_cast<size_t>(bf16_vnni_permute));
vmovdqa64(vmm_permw, ptr[reg_tmp]);

if (is_src_int4) {
alignas(64) static constexpr const uint32_t int4_permute[16]
= {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
mov(reg_tmp, reinterpret_cast<size_t>(int4_permute));
vmovdqa32(vmm_permd, ptr[reg_tmp]);

kmovx(kAAAA, 0xaaaa);

const auto reg32_scratch = reg_tmp.cvt32();
mov(reg32_scratch, 0xf);
vpbroadcastd(vmm_int4_mask, reg32_scratch);

if (conf_->orig_wei_dt == data_type::s4) {
mov(reg32_scratch, 0x8);
vpbroadcastd(vmm_sign_bit, reg32_scratch);

mov(reg32_scratch, 0xfffffff8);
vpbroadcastd(vmm_sign_mask, reg32_scratch);
}
}
}
}

20 changes: 15 additions & 5 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -193,12 +193,11 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
, bf32_dt(f32_dt
&& one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::any)
&& isa == avx512_core_amx)
, bf16_with_int_wei_dt(bgmmc.src_dt == bf16
&& utils::one_of(bgmmc.wei_dt, u8, s8)
&& one_of(bgmmc.dst_dt, bf16, f32))
, weights_decompression_support(one_of(bgmmc.wei_dt, u8, s8)
, weights_decompression_support(one_of(bgmmc.wei_dt, u8, s8, u4, s4)
&& one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::any)
&& attr.fpmath_.apply_to_int_)
, bf16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == bf16
&& one_of(bgmmc.dst_dt, bf16, f32))
, A_any_layout(A_any_layout)
, B_any_layout(B_any_layout)
, C_any_layout(C_any_layout)
@@ -1190,6 +1189,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.is_bf32 = bm_conf_utils.is_bf32();
bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei();
bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression();
bgmmc.is_int4_weights = one_of(bgmmc.wei_dt, data_type::s4, data_type::u4);

// Make BRGeMM compute MatMul as if it were in bfloat16, while down-convert
// happens during copy-buffer computations
@@ -1325,6 +1325,12 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
&& bgmmc.is_oscale_per_k && bgmmc.is_oscale_per_n
&& bgmmc.transposed_B;

// int4 weights decompression only supports plain layout for now
// TODO: enable int4 reorder and extend support to other weight layouts
if (bgmmc.with_wei_decompression && bgmmc.is_int4_weights)
VCONDCHECK_BG(bm_conf_utils.check_is_plain(bgmmc.wei_tag),
VERBOSE_UNSUPPORTED_TAG);

const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag);
// if M == 1 we can still treat formally transposed A as plain
// and avoid copy routine creation/execution
@@ -1679,10 +1685,14 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc,
&& wei_d.matches_one_of_tag(abcd) == format_tag::undef) {
bgmmc.copy_B_wei_stride = bgmmc.K * bgmmc.b_dt_sz;
} else {
const dim_t factor = bgmmc.is_int4_weights ? 2 : 1;
const auto b_stride_elems
= bgmmc.req_wei_vnni_downconvert ? bgmmc.LDB : bgmmc.N;
assert(IMPLICATION(bgmmc.is_int4_weights, b_stride_elems % 2 == 0));
bgmmc.copy_B_wei_stride
= bgmmc.is_runtime_N ? bgmmc.N : b_stride_elems * bgmmc.b_dt_sz;
= (bgmmc.is_runtime_N ? bgmmc.N
: b_stride_elems * bgmmc.b_dt_sz)
/ factor;
}

bgmmc.C_ptr_shift_b = dst_d.matches_one_of_tag(acbd)
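As a worked illustration of the halved copy stride (the numbers are made up, not taken from this commit): with plain-layout int4 weights, N = 64, and b_dt_sz = 1, b_stride_elems is 64 and copy_B_wei_stride becomes 64 * 1 / 2 = 32 bytes per row, matching the two-elements-per-byte packing; the assert above requires b_stride_elems to be even so this division loses nothing.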
8 changes: 5 additions & 3 deletions src/cpu/x64/matmul/brgemm_matmul_utils.hpp
@@ -188,6 +188,7 @@ struct brgemm_matmul_conf_t {
int required_k_granularity;
bool is_bf32 = false;
bool is_bf16_with_int_wei = false;
bool is_int4_weights = false;
bool req_wei_vnni_downconvert = false;
bool is_runtime_M = false;
bool is_runtime_N = false;
@@ -288,7 +289,8 @@ struct brgemm_matmul_conf_utils_t {
inline bool is_bf16_with_int_wei() const { return bf16_with_int_wei_dt; }

inline bool with_weights_decompression() const {
return !utils::one_of(bgmmc.src_dt, data_type::s8, data_type::u8)
return !utils::one_of(bgmmc.src_dt, data_type::s8, data_type::u8,
data_type::s4, data_type::u4)
&& weights_decompression_support;
}

@@ -316,8 +318,8 @@ struct brgemm_matmul_conf_utils_t {
private:
brgemm_matmul_conf_t &bgmmc;

const bool f32_dt, bf16_dt, f16_dt, int8_dt, bf32_dt, bf16_with_int_wei_dt;
const bool weights_decompression_support;
const bool f32_dt, bf16_dt, f16_dt, int8_dt, bf32_dt;
const bool weights_decompression_support, bf16_with_int_wei_dt;
const bool A_any_layout;
const bool B_any_layout;
const bool C_any_layout;