x64: brdgmm conv: enable zps per group
tczeszun committed Dec 16, 2024
1 parent 770f8ef commit 0741e51
Showing 6 changed files with 118 additions and 42 deletions.
1 change: 1 addition & 0 deletions src/common/primitive_attr_quant.hpp
@@ -244,6 +244,7 @@ struct zero_points_t : public c_compatible {

// arg-specific checks
bool common(int arg) const { return get_mask(arg) == 0; }
bool per_oc(int arg) const { return get_mask(arg) == 2; }
bool has_default_values(int arg) const {
return is_set(arg) == false && has_default_data_type(arg);
}
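Mask value 2 here corresponds to zero points that vary along dimension 1 (output channels, which are the groups for a depthwise convolution). A minimal user-side sketch, assuming the oneDNN v3.x C++ API:

#include "dnnl.hpp"

// Minimal sketch: request one src zero point per output channel / group;
// mask 1 << 1 == 2 is exactly what per_oc() matches above.
void request_per_oc_src_zps(dnnl::primitive_attr &attr) {
    attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask = */ 1 << 1);
}
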
14 changes: 9 additions & 5 deletions src/cpu/x64/brgemm/brgemm.cpp
@@ -207,6 +207,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations;
brgemm_p.c_zp_values = post_ops_data.c_zp_values;
brgemm_p.ptr_dst_scales = post_ops_data.dst_scales;
brgemm_p.a_zp_values = post_ops_data.a_zp_values;
if (dynamic_values) {
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
brgemm_p.dynamic_LDB = dynamic_values->dynamic_LDB;
@@ -458,19 +459,22 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg,
auto zero_points = attr->zero_points_;

// only common and per-oc (per-group) zero point types are supported for now
if (!zero_points.common(mem_arg)) return status::unimplemented;
const bool is_per_oc_bcast = zero_points.per_oc(mem_arg);
if (!zero_points.common(mem_arg) && !is_per_oc_bcast)
return status::unimplemented;

const bool skip_zero_point
= mem_arg == DNNL_ARG_WEIGHTS && brg->skip_zp_b_compensation;
zp_type = zero_points.has_default_values(mem_arg) || skip_zero_point
? brgemm_broadcast_t::none
: brgemm_broadcast_t::per_tensor;
: is_per_oc_bcast ? brgemm_broadcast_t::per_n
: brgemm_broadcast_t::per_tensor;
return status::success;
};

init_zp_type(brg->zp_type_a, DNNL_ARG_SRC);
init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS);
init_zp_type(brg->zp_type_c, DNNL_ARG_DST);
CHECK(init_zp_type(brg->zp_type_a, DNNL_ARG_SRC));
CHECK(init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS));
CHECK(init_zp_type(brg->zp_type_c, DNNL_ARG_DST));

// Post-ops may use vector registers so brgemm/brdgmm blocking may need to
// be updated
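The dispatch above boils down to a small mapping from the user-set mask to a broadcast kind. A standalone sketch of that logic (names are illustrative, not the library's):

// Sketch: `mask` is the user-set quantization mask for the argument;
// `is_set` and `skip` mirror has_default_values() and skip_zero_point.
enum class zp_broadcast_t { none, per_tensor, per_n, unsupported };

zp_broadcast_t pick_zp_broadcast(bool is_set, int mask, bool skip) {
    if (mask != 0 && mask != 2) return zp_broadcast_t::unsupported;
    if (!is_set || skip) return zp_broadcast_t::none;
    return mask == 2 ? zp_broadcast_t::per_n : zp_broadcast_t::per_tensor;
}
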
8 changes: 6 additions & 2 deletions src/cpu/x64/brgemm/brgemm_types.hpp
@@ -491,6 +491,7 @@ struct brgemm_kernel_params_t {
const void *a_zp_compensations = nullptr;
const void *b_zp_compensations = nullptr;
const void *c_zp_values = nullptr;
const void *a_zp_values = nullptr;
size_t skip_accm = 0;
int32_t zp_a_val = 1;
const void *ptr_dst_scales = nullptr;
@@ -600,7 +601,8 @@ struct brgemm_post_ops_data_t {
const void *b_zp_compensations = nullptr,
const void *c_zp_values = nullptr, bool skip_accumulation = false,
int32_t zp_a_val = 1, bool do_only_comp = false,
bool do_only_zp_a_val = false, const float *dst_scales = nullptr)
bool do_only_zp_a_val = false, const float *dst_scales = nullptr,
const void *a_zp_values = nullptr)
: bias(bias)
, scales(scales)
, binary_post_ops_rhs(binary_post_ops_rhs)
@@ -615,7 +617,8 @@ struct brgemm_post_ops_data_t {
, zp_a_val {zp_a_val}
, do_only_comp {do_only_comp}
, do_only_zp_a_val {do_only_zp_a_val}
, dst_scales(dst_scales) {}
, dst_scales(dst_scales)
, a_zp_values(a_zp_values) {}

const void *bias = nullptr;
const float *scales = nullptr;
@@ -632,6 +635,7 @@ struct brgemm_post_ops_data_t {
const bool do_only_comp = false;
const bool do_only_zp_a_val = false;
const float *dst_scales = nullptr;
const void *a_zp_values = nullptr;
};

} // namespace x64
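Because a_zp_values is appended as a defaulted trailing constructor parameter, every existing call site stays source-compatible; only callers that want per-group zero points pass the buffer. A toy illustration of the pattern (the struct is a stand-in, not the real brgemm_post_ops_data_t):

#include <cstdint>

struct post_ops_data_toy_t {
    post_ops_data_toy_t(
            const void *bias = nullptr, const void *a_zp_values = nullptr)
        : bias(bias), a_zp_values(a_zp_values) {}
    const void *bias;
    const void *a_zp_values;
};

static const int32_t zp_buf[4] = {1, 2, 3, 4};
static post_ops_data_toy_t legacy; // a_zp_values stays nullptr, old behavior
static post_ops_data_toy_t per_group(nullptr, zp_buf); // per-group src zps
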
120 changes: 90 additions & 30 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
@@ -45,6 +45,7 @@ jit_brdgmm_kernel_base_t<Wmm>::jit_brdgmm_kernel_base_t(
, max_vmms_(isa_num_vregs(brg.isa_impl))
, compute_dst_zp_(brg.zp_type_c != brgemm_broadcast_t::none)
, compute_src_zp_(brg.zp_type_a != brgemm_broadcast_t::none)
, is_src_zp_bcast_(brg.zp_type_a == brgemm_broadcast_t::per_tensor)
, compute_compensation_(compute_src_zp_ || brg.req_s8s8_compensation)
, has_vpad_(brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0)
, has_bpad_(brg.brgattr.max_top_bpad > 0 || brg.brgattr.max_bottom_bpad > 0)
@@ -147,7 +148,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::read_params() {
}

if (compute_src_zp_) {
mov(reg_tmp, ptr[param1 + GET_OFF(zp_a_val)]);
mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_values)]);
mov(ptr[rsp + src_zp_value_], reg_tmp);

mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_compensations)]);
@@ -609,6 +610,17 @@ void jit_brdgmm_kernel_base_t<Wmm>::maybe_transpose_interleaved_vnni_to_plain(
}
}

template <typename Wmm>
void jit_brdgmm_kernel_base_t<Wmm>::load_src_zp() {
mov(reg_src_zero_point, ptr[rsp + src_zp_value_]);
lea(reg_src_zero_point,
is_src_zp_bcast_
? ptr_b[reg_src_zero_point]
: ptr[reg_src_zero_point + reg_aux_N * sizeof(int32_t)]);
if (!is_superset(brg.isa_impl, avx512_core) && is_src_zp_bcast_)
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
}

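The new load_src_zp() helper centralizes the two addressing modes: with a common zero point the single int32 is broadcast to all lanes (EVEX embedded broadcast via ptr_b on AVX-512, an explicit uni_vpbroadcastd on AVX2), while per-group values are read lane-wise starting at the channel offset held in reg_aux_N. A scalar model of the address it produces, assuming one int32 per channel in the non-broadcast buffer:

#include <cstdint>

// Hedged scalar model of load_src_zp(): which address the vector code
// reads its zero point(s) from. n_off counts int32 elements.
static inline const int32_t *src_zp_base(
        const int32_t *zp, bool is_bcast, int64_t n_off) {
    return is_bcast ? zp : zp + n_off;
}
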
template <typename Wmm>
void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
int m_blocks, int n_blocks, bool has_n_tail) {
@@ -620,12 +632,10 @@ void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
lea(reg_s8s8_comp, ptr[reg_s8s8_comp + reg_aux_N * sizeof(int32_t)]);
}
if (compute_src_zp_) {
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
load_src_zp();
mov(reg_zp_compensation, ptr[rsp + zp_compensation_]);
lea(reg_zp_compensation,
ptr[reg_zp_compensation + reg_aux_N * sizeof(int32_t)]);
if (!is_superset(brg.isa_impl, avx512_core))
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
}

for_(int v_i = 0; v_i < v_substep; ++v_i)
@@ -640,16 +650,35 @@ void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
}
if (compute_src_zp_) {
// zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
const Vmm vmm_zp = vmm_zp_comp();
vmovups(vmm_zp,
maybe_EVEX_compress_addr(reg_zp_compensation, offset));
if (is_superset(brg.isa_impl, avx512_core)) {
const bool src_zp_is_common = true;
vpmulld(vmm_zp, vmm_zp,
maybe_EVEX_compress_addr(
reg_src_zero_point, 0, src_zp_is_common));
const bool is_tail
= n + 1 == n_blocks && has_n_tail && substep_simd < simd_w_;
const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
? maybe_mask(vmm_zp_comp(), is_tail, false)
: vmm_zp_comp();
if (IMPLICATION(is_tail, isa_has_masks(brg.isa_impl))) {
vmovups(vmm_zp,
maybe_EVEX_compress_addr(reg_zp_compensation, offset));
if (is_src_zp_bcast_) {
if (is_superset(brg.isa_impl, avx512_core))
vpmulld(vmm_zp, vmm_zp,
maybe_EVEX_compress_addr(
reg_src_zero_point, 0, true));
else
vpmulld(vmm_zp, vmm_zp, vmm_bcast());
} else
vpmulld(vmm_zp, vmm_zp,
maybe_EVEX_compress_addr(
reg_src_zero_point, offset));
} else {
vpmulld(vmm_zp, vmm_zp, vmm_bcast());
const int tail_size = tail_length();
const Vmm ymm_tmp
= vmm_bcast(); // used for bcast or tail processing in avx2
load_data(data_type::s32, vmm_zp,
ptr[reg_zp_compensation + offset], tail_size);
if (!is_src_zp_bcast_)
load_data(data_type::s32, ymm_tmp,
ptr[reg_src_zero_point + offset], tail_size);
vpmulld(vmm_zp, vmm_zp, ymm_tmp);
}
}
for (int m = 0; m < m_blocks; m++) {
@@ -795,24 +824,48 @@ void jit_brdgmm_kernel_base_t<Wmm>::load_b(

template <typename Wmm>
void jit_brdgmm_kernel_base_t<Wmm>::comp_dot_product(
compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb) {
compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb, int n,
bool is_tail_block) {
switch (kernel_type) {
case compute_pad_kernel_t::s8s8_kernel:
vpdpbusd(vmm_acc, vmm_shift(), vmmb,
is_superset(brg.isa_impl, avx512_core)
? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
break;
case compute_pad_kernel_t::zero_point_kernel:
if (is_superset(brg.isa_impl, avx512_core)) {
vpmulld(vmm_zp_comp(), vmmb,
maybe_EVEX_compress_addr(reg_src_zero_point, 0, true));
case compute_pad_kernel_t::zero_point_kernel: {
const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
? maybe_mask(vmm_zp_comp(), is_tail_block, false)
: vmm_zp_comp();
const size_t offset = comp_offset(n);
if (IMPLICATION(is_tail_block, isa_has_masks(brg.isa_impl))) {
if (is_src_zp_bcast_) {
if (is_superset(brg.isa_impl, avx512_core))
vpmulld(vmm_zp, vmmb,
maybe_EVEX_compress_addr(
reg_src_zero_point, 0, true));
else
vpmulld(vmm_zp, vmmb, vmm_bcast());
} else {
const Xbyak::Address src_zp_addr = maybe_EVEX_compress_addr(
reg_src_zero_point, offset);
if (is_fast_vnni_int8()) {
vmovups(vmm_zp, src_zp_addr);
vpermd(vmm_zp, vmm_permute(), vmm_zp);
vpmulld(vmm_zp, vmmb, vmm_zp);
} else
vpmulld(vmm_zp, vmmb, src_zp_addr);
}
} else {
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
vpmulld(vmm_zp_comp(), vmmb, vmm_bcast());
const Vmm ymm_tmp
= vmm_bcast(); // used for bcast or tail processing in avx2
if (!is_src_zp_bcast_)
load_data(data_type::s32, ymm_tmp,
ptr[reg_src_zero_point + offset], tail_length());
vpmulld(vmm_zp, vmmb, ymm_tmp);
}
vpaddd(vmm_acc, vmm_acc, vmm_zp_comp());
break;
} break;
default: assert(!"unsupported comp_kernel type");
}
}
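In the padded-row path the zero-point case now also honors per-group values; a scalar model of what the vpmulld/vpaddd pair accumulates (hedged: vmmb is assumed to hold the widened s32 weights of the current N block):

#include <cstdint>

// Hedged scalar model of the zero_point_kernel branch: padded input
// positions contribute wei[c] * src_zp(c) per channel c; src_zp(c)
// collapses to a single value in the broadcast case.
void zp_pad_compensate(int32_t *acc, const int32_t *wei_block,
        const int32_t *src_zp, bool is_bcast, int block_len) {
    for (int c = 0; c < block_len; ++c)
        acc[c] += wei_block[c] * (is_bcast ? src_zp[0] : src_zp[c]);
}
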
@@ -853,21 +906,25 @@ void jit_brdgmm_kernel_base_t<Wmm>::pad_comp_kernel(

for (int pad_i = max_m_unroll; pad_i > 0; --pad_i) {
L(jmp_table_labels[pad_i]);
if (is_zero_point_kernel)
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
if (is_zero_point_kernel) load_src_zp();
if (pad_i > m_blocks) continue;
const int m_i = get_mi(pad_i);
int p_b_i = 0;
for (int n_i = 0; n_i < n_blocks; ++n_i, ++p_b_i) {
if (get_substep_simd(n_i, 0, has_tail) <= 0) continue;
const int substep_simd = get_substep_simd(n_i, 0, has_tail);
if (substep_simd <= 0) continue;
const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0);
const bool is_tail_block
= n_i + 1 == n_blocks && has_tail && substep_simd < simd_w_;
if (p_b_i < n_preload_b_vmms) {
comp_dot_product(kernel_type, vmm_acc, vmm_b(p_b_i));
comp_dot_product(
kernel_type, vmm_acc, vmm_b(p_b_i), n_i, is_tail_block);
} else {
// preloaded vmm_b not available
const Vmm vmm_wei = vmm_b(max_bvmms - 1);
load_b(vmm_wei, n_i, 0, has_tail, load_broadcast_wei);
comp_dot_product(kernel_type, vmm_acc, vmm_wei);
comp_dot_product(
kernel_type, vmm_acc, vmm_wei, n_i, is_tail_block);
}
}
}
@@ -885,8 +942,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::batch_pad_kernel(
auto kernel_body = [&](compute_pad_kernel_t kernel_type) {
const bool is_zero_point_kernel
= kernel_type == compute_pad_kernel_t::zero_point_kernel;
if (is_zero_point_kernel)
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
if (is_zero_point_kernel) load_src_zp();
for (int nb_i = 0; nb_i < n_blocks; nb_i += max_bvmms) {
const int n_e = nstl::min(nb_i + max_bvmms, n_blocks) - nb_i;
for (int i = 0; i < n_e; ++i) {
@@ -898,9 +954,13 @@ void jit_brdgmm_kernel_base_t<Wmm>::batch_pad_kernel(
for_(int m_i = 0; m_i < m_blocks; ++m_i)
for (int i = 0; i < n_e; ++i) {
const int n_i = nb_i + i;
if (get_substep_simd(n_i, 0, has_tail) <= 0) continue;
const int substep_simd = get_substep_simd(n_i, 0, has_tail);
if (substep_simd <= 0) continue;
const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0);
comp_dot_product(kernel_type, vmm_acc, vmm_b(i));
const bool is_tail_block
= n_i + 1 == n_blocks && has_tail && substep_simd < simd_w_;
comp_dot_product(
kernel_type, vmm_acc, vmm_b(i), n_i, is_tail_block);
}
}
};
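Taken together, the kernel now implements the formula in the compute_int8_compensation comment with a per-channel rather than scalar src zero point. A scalar reference, assuming one zero point per output channel when not broadcast:

#include <cstdint>

// Hedged scalar reference: dst[m][c] -= src_zp(c) * zp_comp[c], where
// zp_comp[c] is the precomputed sum of s8 weights for channel c.
void apply_src_zp_compensation(int32_t *dst, const int32_t *zp_comp,
        const int32_t *src_zp, bool is_bcast, int M, int C) {
    for (int m = 0; m < M; ++m)
        for (int c = 0; c < C; ++c)
            dst[m * C + c] -= (is_bcast ? src_zp[0] : src_zp[c]) * zp_comp[c];
}
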
5 changes: 4 additions & 1 deletion src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp
@@ -230,6 +230,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
const int simd_w_;
const int max_vmms_;
const bool compute_dst_zp_, compute_src_zp_;
const bool is_src_zp_bcast_;
const bool compute_compensation_; // code-path for either s8s8 or src_zp
const bool has_vpad_; // vertical padding w.r.t. M dimension
const bool has_bpad_; // batch pad is computed for the overlap between the
@@ -341,7 +342,8 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
void load_b(
Vmm vmmb, int n_i, int v_i, bool has_n_tail, bool wei_zp = false);
void comp_dot_product(compute_pad_kernel_t kernel_type, Vmm vmm_acc,
Vmm vmmb); // int8 compensation dot_product (zp and s8s8)
Vmm vmmb, int n,
bool is_tail_block); // int8 compensation dot_product (zp and s8s8)
void pad_comp_kernel(compute_pad_kernel_t kernel_type, int m_blocks,
int n_blocks, int padding, const Xbyak::Reg64 reg_pad,
const std::function<int(int)> &get_mi, bool has_tail = false);
@@ -360,6 +362,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
void apply_post_ops(int m_blocks, int n_blocks, bool has_n_tail);
void maybe_transpose_interleaved_vnni_to_plain(
int m_blocks, int n_blocks, bool has_n_tail);
void load_src_zp();
void compute_int8_compensation(int m_blocks, int n_blocks, bool has_n_tail);
void store_accumulators(int m_blocks, int n_blocks, bool has_n_tail);
void store_accumulators_without_post_ops(
12 changes: 8 additions & 4 deletions src/cpu/x64/jit_brdgmm_dw_conv.cpp
@@ -255,7 +255,8 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {
const bool params_ok
= IMPLICATION(has_zero_points, utils::one_of(jcp.src_dt, u8, s8))
&& IMPLICATION(jcp.src_zero_point,
attr()->zero_points_.common(DNNL_ARG_SRC))
attr()->zero_points_.common(DNNL_ARG_SRC)
|| attr()->zero_points_.per_oc(DNNL_ARG_SRC))
&& IMPLICATION(jcp.dst_zero_point,
attr()->zero_points_.common(DNNL_ARG_DST));
VDISPATCH_CONV(params_ok, VERBOSE_UNSUPPORTED_ZP_CFG);
@@ -583,7 +584,7 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

const int wei_scale_mask
@@ -753,8 +754,11 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
post_ops_data.scales = &oscales[jcp.is_oc_scale * ch];
post_ops_data.oc_logical_off = ch;
post_ops_data.dst_scales = dst_scales;
post_ops_data.zp_a_val
= jcp.src_zero_point ? src_zero_point : 1;
const bool is_bcast_zp
= pd()->attr()->zero_points_.common(DNNL_ARG_SRC);
post_ops_data.a_zp_values = jcp.src_zero_point
? src_zero_point + ch * !is_bcast_zp
: nullptr;
post_ops_data.c_zp_values
= jcp.dst_zero_point ? dst_zero_point : nullptr;
post_ops_data.a_zp_compensations
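On the user side the per-group zero points arrive as a runtime s32 buffer. A hedged end-to-end sketch (engine, stream, and primitive setup omitted; G stands in for the group count, and the primitive is assumed to have been created with set_zero_points_mask(DNNL_ARG_SRC, 1 << 1)):

#include <cstdint>
#include <unordered_map>
#include "dnnl.hpp"

// Hedged sketch: attach one s32 src zero point per group at execution time.
void execute_with_per_group_src_zps(dnnl::convolution_forward &conv,
        dnnl::engine &eng, dnnl::stream &strm,
        std::unordered_map<int, dnnl::memory> args, int64_t G, int32_t *zps) {
    auto zp_md = dnnl::memory::desc({G}, dnnl::memory::data_type::s32,
            dnnl::memory::format_tag::x);
    auto zp_mem = dnnl::memory(zp_md, eng, zps);
    args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_mem});
    conv.execute(strm, args);
}
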
