Skip to content

Commit

Permalink
xe: jit: gemm: enable 0D/1D early dequantization cases
Browse files: browse the repository at this point in the history.
  • Loading branch information
petercad committed Jul 23, 2024
1 parent 0f63229 commit baa6eaf
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 167 deletions.
7 changes: 1 addition & 6 deletions src/gpu/intel/jit/gemm/gen_gemm.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -118,7 +118,7 @@ struct gen_gemm_t : public gpu_gemm_t {
| smask_t::zero_points_runtime_groups;
}

bool wei_zp = false, wei_zp_2d = false;
bool wei_zp_2d = false;
auto wei_scales_type = data_type::undef;
auto src_scales_type = data_type::undef;
int wei_q2d_group_k = 0;
Expand Down Expand Up @@ -203,7 +203,6 @@ struct gen_gemm_t : public gpu_gemm_t {
CHECK(attr_zps.get(DNNL_ARG_B, &cmask_b));
CHECK(attr_zps.get(DNNL_ARG_C, &cmask_c));

wei_zp = a_zp;
wei_zp_2d = attr_zps.get_groups_ndims(DNNL_ARG_A) > 1;
VDISPATCH_GEMM(
(utils::one_of(cmask_a, 0, 1 << 1, 1 << 2) || wei_zp_2d)
Expand All @@ -228,7 +227,6 @@ struct gen_gemm_t : public gpu_gemm_t {
auto &src_scales = attr()->scales_.get(DNNL_ARG_SRC);

if (quant_enabled_ && wei_scales.ndims_ > 1) wei_scales_2d_ = true;

if (quant_enabled_ && src_scales.ndims_ > 1) src_scales_2d_ = true;

for (auto s : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
Expand All @@ -245,9 +243,6 @@ struct gen_gemm_t : public gpu_gemm_t {
if (scales_group_k >= d->k()) {
wei_scales_2d_ = false;
} else {
VDISPATCH_GEMM(
!(wei_zp && (ao_dims_ == 1 || bo_dims_ == 1)),
VERBOSE_UNSUPPORTED_ZP_CFG);
wei_scales_type = wei_scales.data_type_;
if (!wei_zp_2d)
wei_q2d_group_k = scales_group_k;
Expand Down
9 changes: 4 additions & 5 deletions src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -362,9 +362,9 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
}
}

if (problem_.Ta_ext.isInt4() && problem_.Tb_ext.isInt8() && ao_dims >= 2)
if (problem_.Ta_ext.isInt4() && problem_.Tb_ext.isInt8() && ao_dims >= 0)
problem_.Ta = Type::s8;
if (problem_.Tb_ext.isInt4() && problem_.Ta_ext.isInt8() && bo_dims >= 2)
if (problem_.Tb_ext.isInt4() && problem_.Ta_ext.isInt8() && bo_dims >= 0)
problem_.Tb = Type::s8;

if (problem_.Ta.isInteger()) problem_.Ts = Type::f32;
Expand All @@ -391,7 +391,7 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
MatchParams match_params[3];
int npatterns = 1;

match_params[0] = MatchParams(hw_, problem_);
match_params[0] = MatchParams(hw_, has_systolic, problem_);

match_params[0].sizes.m = m;
match_params[0].sizes.n = n;
Expand All @@ -405,7 +405,6 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
if (can_2d_a) *tags++ = kcatalog::ReqBlock2DA;
if (can_2d_b) *tags++ = kcatalog::ReqBlock2DB;
if (can_2d_c) *tags++ = kcatalog::ReqBlock2DC;
if (has_systolic) *tags++ = kcatalog::ReqSystolic;

if ((mode & mode_tf32)
&& utils::everyone_is(Type::f32, problem_.Ta, problem_.Tb)) {
Expand Down Expand Up @@ -600,7 +599,7 @@ status_t gen_gemm_xe_systolic_kernel_desc_t::select_kernel(
}

// Find it in the catalog.
MatchParams match_params(hw_, problem_);
MatchParams match_params(hw_, true, problem_);

match_params.sizes.m = m;
match_params.sizes.n = n;
Expand Down
Loading

0 comments on commit baa6eaf

Please sign in to comment.