Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

src: gpu: intel: jit: conv: add reorder-based precomputed zero points #2267

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/common/memory_desc.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022-2024 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -471,8 +471,9 @@ status_t memory_desc_permute_axes(memory_desc_t &out_memory_desc,
VCHECK_MEMORY(
!memory_desc_wrapper(in_memory_desc).has_runtime_dims_or_strides(),
invalid_arguments, VERBOSE_UNSUPPORTED_MEM_STRIDE);
VCHECK_MEMORY(in_memory_desc.extra.flags == 0, invalid_arguments,
VERBOSE_UNSUPPORTED_MD_FLAG, "extra");
VCHECK_MEMORY(
check_md_extra_flags_compensation_gpu(in_memory_desc.extra.flags),
invalid_arguments, VERBOSE_UNSUPPORTED_MD_FLAG, "extra");

// verify that perm is indeed a permutation of [0 .. ndims)
unsigned occurrence_mask = 0;
Expand Down
55 changes: 47 additions & 8 deletions src/common/memory_desc.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
* Copyright 2024-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -56,21 +56,30 @@ const rnn_packed_memory_format_t ldio_p = rnn_packed_memory_format_t::ldio_p;
// TODO: convert to 'enum class'.
// Flags for memory special features
enum memory_extra_flags_t {
dnnl_memory_extra_flag_none = 0x0U,
dnnl_memory_extra_flag_none = 0u,
// Indicates the weights have an additional buffer, that depends on the
// @p compensation_mask.
//
// For instance, in the 4D case with the compensation mask equal to (1 << 0),
// the additional buffer would consist of OC values:
// O[oc : 0,OC] =
// -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) }
dnnl_memory_extra_flag_compensation_conv_s8s8 = 0x1U,
dnnl_memory_extra_flag_scale_adjust = 0x2U,
dnnl_memory_extra_flag_rnn_u8s8_compensation = 0x4U,
dnnl_memory_extra_flag_compensation_conv_s8s8 = 1u,
dnnl_memory_extra_flag_scale_adjust = 2u,
dnnl_memory_extra_flag_rnn_u8s8_compensation = 4u,
dnnl_memory_extra_flag_gpu_rnn_u8s8_compensation
= dnnl_memory_extra_flag_rnn_u8s8_compensation,
dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 0x8U,
dnnl_memory_extra_flag_rnn_s8s8_compensation = 0x16U,
dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 8u,
dnnl_memory_extra_flag_rnn_s8s8_compensation = 16u,
// This flag has to be kept separate from *compensation_conv_asymmetric_src
// since the GPU precompute algorithm is incompatible with that of the CPU
dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src = 32u,
// This flag depends on *compensation_gpu_conv_asymmetric_src and is used
// when precompute is to be performed for a backward-by-data convolution
dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_bwd = 64u,
// This flag depends on *compensation_gpu_conv_asymmetric_src and is used
// when IC and OC are swapped to reinterpret a deconv as a BWD_D conv
dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_swap = 128u,
};

// Create aliases for extra flags to preserve the old behavior.
Expand All @@ -87,8 +96,23 @@ const memory_extra_flags_t rnn_s8s8_compensation
= dnnl_memory_extra_flag_rnn_s8s8_compensation;
const memory_extra_flags_t compensation_conv_asymmetric_src
= dnnl_memory_extra_flag_compensation_conv_asymmetric_src;
const memory_extra_flags_t compensation_gpu_conv_asymmetric_src
= dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src;
const memory_extra_flags_t compensation_gpu_conv_asymmetric_src_bwd
= dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_bwd;
const memory_extra_flags_t compensation_gpu_conv_asymmetric_src_swap
= dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_swap;
} // namespace memory_extra_flags

// Validates the extra flags of a memory descriptor against the combinations
// allowed for GPU precomputed zero-point compensation.
//
// Returns true only when @p flags is exactly one of:
//   - none
//   - compensation_gpu_conv_asymmetric_src
//   - compensation_gpu_conv_asymmetric_src | *_bwd
//   - compensation_gpu_conv_asymmetric_src | *_bwd | *_swap
// Any other value (including any unrelated extra flag) yields false.
inline bool check_md_extra_flags_compensation_gpu(uint64_t flags) {
    using namespace memory_extra_flags;

    // Build the accepted combinations incrementally: each one extends the
    // previous with one more flag.
    const uint64_t base_only = compensation_gpu_conv_asymmetric_src;
    const uint64_t with_bwd
            = base_only | compensation_gpu_conv_asymmetric_src_bwd;
    const uint64_t with_bwd_swap
            = with_bwd | compensation_gpu_conv_asymmetric_src_swap;

    if (flags == none) return true;
    if (flags == base_only) return true;
    if (flags == with_bwd) return true;
    return flags == with_bwd_swap;
}

// Generic description of blocked data layout for most memory formats.
struct blocking_desc_t {
// The strides between the outermost blocks.
Expand Down Expand Up @@ -208,7 +232,12 @@ struct memory_extra_desc_t {
: flags(0)
, compensation_mask(0)
, scale_adjust(0.0f)
, asymm_compensation_mask(0) {}
, asymm_compensation_mask(0)
, idhw {0, 0, 0}
, odhw {0, 0, 0}
, pdhw {0, 0, 0}
, ddhw {0, 0, 0}
, dst_size(0) {}
// The flags contain arbitrary extra information, such as compensation.
// @sa dnnl_memory_extra_flags_t
uint64_t flags;
Expand All @@ -218,6 +247,16 @@ struct memory_extra_desc_t {
float scale_adjust;
// Compensation mask for asymmetric quantization
int asymm_compensation_mask;
// Precomp GPU ZP convolution input spatials
dim_t idhw[3];
// Precomp GPU ZP convolution output spatials
dim_t odhw[3];
// Precomp GPU ZP convolution padding spatials
dim_t pdhw[3];
// Precomp GPU ZP convolution dilation spatials
dim_t ddhw[3];
// Precomp GPU ZP convolution destination size
dim_t dst_size;
};

status_t DNNL_API memory_desc_init_by_tag(memory_desc_t &memory_desc, int ndims,
Expand Down
36 changes: 18 additions & 18 deletions src/common/memory_desc_wrapper.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -149,30 +149,28 @@ struct memory_desc_wrapper : public c_compatible {
size_t additional_buffer_data_size(uint64_t flag_select) const {
using namespace memory_extra_flags;
if (flag_select & compensation_conv_s8s8) return sizeof(int32_t);
if ((flag_select & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(flag_select))
return sizeof(float);
if (flag_select & rnn_u8s8_compensation) return sizeof(float);
if (flag_select & compensation_conv_asymmetric_src)
return sizeof(int32_t);
if (flag_select & compensation_gpu_conv_asymmetric_src)
return sizeof(int32_t);
return 0;
}

/** return true if memory format has additional buffer */
bool is_additional_buffer() const {
using namespace memory_extra_flags;
// Currently compensation is not required for rnn_s8s8_compensation,
// but it has common bit with rnn_u8s8_compensation constant so we have
// to exclude rnn_s8s8_compensation case explicitly
return ((extra().flags
& (compensation_conv_s8s8 | rnn_u8s8_compensation
| compensation_conv_asymmetric_src))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
extra().flags));
return extra().flags
& (compensation_conv_s8s8 | rnn_u8s8_compensation
| compensation_gpu_conv_asymmetric_src
| compensation_conv_asymmetric_src);
}

/** returns the size required for a particular extra memory buffer */
size_t additional_buffer_size(memory_extra_flags_t flag) const {
using namespace memory_extra_flags;
const auto flags = extra().flags;
if (!(flags & flag)) return 0;

const auto ndims = this->ndims();
const auto &pdims = padded_dims();
Expand All @@ -186,21 +184,21 @@ struct memory_desc_wrapper : public c_compatible {
return (size_t)prod * buff_data_size;
};

if (extra().flags & compensation_conv_s8s8) {
if (flag == compensation_conv_s8s8) {
return calculate_size(extra().compensation_mask,
additional_buffer_data_size(flag));
}

if ((extra().flags & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
extra().flags)) {
if (flag == rnn_u8s8_compensation) {
return calculate_size(extra().compensation_mask,
additional_buffer_data_size(flag));
}
if (extra().flags & compensation_conv_asymmetric_src) {
if (flag == compensation_conv_asymmetric_src) {
return calculate_size(extra().asymm_compensation_mask,
additional_buffer_data_size(flag));
}
if (flag == compensation_gpu_conv_asymmetric_src) {
return extra().dst_size;
}

return 0;
}
Expand All @@ -220,6 +218,8 @@ struct memory_desc_wrapper : public c_compatible {
buff_size += additional_buffer_size(compensation_conv_s8s8);
buff_size += additional_buffer_size(rnn_u8s8_compensation);
buff_size += additional_buffer_size(compensation_conv_asymmetric_src);
buff_size
+= additional_buffer_size(compensation_gpu_conv_asymmetric_src);
return buff_size;
}

Expand Down
17 changes: 12 additions & 5 deletions src/common/primitive_hashing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,9 @@ size_t get_md_hash(const memory_desc_t &md) {

if (md.extra.flags != dnnl_memory_extra_flag_none) {
seed = hash_combine(seed, md.extra.flags);
if ((md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
md.extra.flags)) {
if (md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
seed = hash_combine(seed, md.extra.compensation_mask);
}

Expand All @@ -206,6 +204,15 @@ size_t get_md_hash(const memory_desc_t &md) {
& dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
seed = hash_combine(seed, md.extra.asymm_compensation_mask);
}

if (md.extra.flags
& dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src) {
seed = get_array_hash(seed, md.extra.idhw, 3);
seed = get_array_hash(seed, md.extra.odhw, 3);
seed = get_array_hash(seed, md.extra.pdhw, 3);
seed = get_array_hash(seed, md.extra.ddhw, 3);
seed = hash_combine(seed, md.extra.dst_size);
}
}
// Combined hash for a memory descriptor
return seed;
Expand Down
18 changes: 11 additions & 7 deletions src/common/serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,26 @@ void serialize_md(serialization_stream_t &sstream, const memory_desc_t &md) {

if (md.extra.flags != dnnl_memory_extra_flag_none) {
sstream.write(&md.extra.flags);
if ((md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
md.extra.flags)) {
if (md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
sstream.write(&md.extra.compensation_mask);
}

if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
sstream.write(&md.extra.scale_adjust);
}

if (md.extra.flags
& dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
sstream.write(&md.extra.asymm_compensation_mask);
}
if (md.extra.flags
& dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src) {
sstream.write(md.extra.idhw, 3);
sstream.write(md.extra.odhw, 3);
sstream.write(md.extra.pdhw, 3);
sstream.write(md.extra.ddhw, 3);
sstream.write(&md.extra.dst_size);
}
}
}

Expand Down
27 changes: 10 additions & 17 deletions src/common/type_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,31 +299,24 @@ inline format_kind_t format_tag_to_kind(format_tag_t tag) {
return format_kind::undef;
}

// rnn_s8s8_compensation currently shares bits with the rnn_u8s8_compensation
// and scale_adjust constants, so testing for any overlapping bit is not
// enough to tell the cases apart. The flag is reported as set only when every
// bit of rnn_s8s8_compensation is present in @p flags.
inline bool extra_flag_rnn_s8s8_compensation_is_set(uint64_t flags) {
    const uint64_t s8s8_bits = memory_extra_flags::rnn_s8s8_compensation;
    return (flags & s8s8_bits) == s8s8_bits;
}

inline bool memory_extra_desc_is_equal(
const memory_extra_desc_t &lhs, const memory_extra_desc_t &rhs) {
using namespace memory_extra_flags;
return true && lhs.flags == rhs.flags
return lhs.flags == rhs.flags
&& IMPLICATION(lhs.flags & compensation_conv_s8s8,
lhs.compensation_mask == rhs.compensation_mask)
&& IMPLICATION((lhs.flags & rnn_u8s8_compensation)
&& !extra_flag_rnn_s8s8_compensation_is_set(
lhs.flags),
&& IMPLICATION(lhs.flags & rnn_u8s8_compensation,
lhs.compensation_mask == rhs.compensation_mask)
&& IMPLICATION((lhs.flags & scale_adjust)
&& !extra_flag_rnn_s8s8_compensation_is_set(
lhs.flags),
&& IMPLICATION(lhs.flags & scale_adjust,
lhs.scale_adjust == rhs.scale_adjust)
&& IMPLICATION(lhs.flags & compensation_conv_asymmetric_src,
lhs.asymm_compensation_mask == rhs.asymm_compensation_mask);
lhs.asymm_compensation_mask == rhs.asymm_compensation_mask)
&& IMPLICATION(lhs.flags & compensation_gpu_conv_asymmetric_src,
(lhs.dst_size == rhs.dst_size)
&& utils::array_cmp(lhs.idhw, rhs.idhw, 3)
&& utils::array_cmp(lhs.odhw, rhs.odhw, 3)
&& utils::array_cmp(lhs.pdhw, rhs.pdhw, 3)
&& utils::array_cmp(lhs.ddhw, rhs.ddhw, 3));
}

inline bool blocking_desc_is_equal(const memory_desc_t &lhs_md,
Expand Down
15 changes: 15 additions & 0 deletions src/common/verbose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,21 @@ std::ostream &operator<<(std::ostream &ss, const memory_extra_desc_t &extra) {
ss << ":s8m" << extra.compensation_mask;
if (extra.flags & compensation_conv_asymmetric_src)
ss << ":zpm" << extra.asymm_compensation_mask;
if (extra.flags & compensation_gpu_conv_asymmetric_src) {
ss << ":zid" << extra.idhw[0];
ss << ":zih" << extra.idhw[1];
ss << ":ziw" << extra.idhw[2];
ss << ":zod" << extra.odhw[0];
ss << ":zoh" << extra.odhw[1];
ss << ":zow" << extra.odhw[2];
ss << ":zpd" << extra.pdhw[0];
ss << ":zph" << extra.pdhw[1];
ss << ":zpw" << extra.pdhw[2];
ss << ":zdd" << extra.ddhw[0];
ss << ":zdh" << extra.ddhw[1];
ss << ":zdw" << extra.ddhw[2];
ss << ":zs" << extra.dst_size;
}
if (extra.flags & scale_adjust && extra.scale_adjust != 1.f)
ss << ":sa" << extra.scale_adjust;
return ss;
Expand Down
5 changes: 4 additions & 1 deletion src/cpu/reorder/cpu_reorder_pd.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,6 +38,9 @@ struct cpu_reorder_pd_t : public reorder_pd_t {
post_ops.len() == 1
&& post_ops.entry_[0].kind == primitive_kind::sum);
VDISPATCH_REORDER(args_ok, VERBOSE_UNSUPPORTED_POSTOP);
auto gpu_zp = memory_extra_flags::compensation_gpu_conv_asymmetric_src;
VDISPATCH_REORDER(!(dst_md()->extra.flags & gpu_zp),
VERBOSE_UNSUPPORTED_MD_FLAG, "extra");
return status::success;
}

Expand Down
11 changes: 2 additions & 9 deletions src/cpu/rnn/rnn_reorders.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -779,12 +779,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
return unimplemented;

// Check the proper memory desc has been passed to u8s8 and s8s8
// Note: currently rnn_u8s8_compensation and rnn_s8s8_compensation
// have common bit so we have to perform additional checks to
// separate these two cases
const bool check_u8s8 = (od.extra().flags & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
od.extra().flags)
&& od.extra().compensation_mask
== ((id.ndims() == 5) ? 27 /* 11011 */
: 13 /* 1101 */);
Expand Down Expand Up @@ -886,9 +881,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
.template get<void>(memory_tracking::names::
key_reorder_rnn_weights_reduction);
float *comp = reinterpret_cast<float *>(dst + compensation_offset);
const bool req_s8s8_comp = (dst_d.extra().flags & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
dst_d.extra().flags);
const bool req_s8s8_comp = dst_d.extra().flags & rnn_u8s8_compensation;
const auto mask_ok = [&](int mask) {
return mask
== ((src_d.ndims() == 5) ? 27 /* 11011 */
Expand Down
Loading
Loading