Skip to content

Commit

Permalink
src: common: make rnn_s8s8_compensation a power of 2
Browse files Browse the repository at this point in the history
  • Loading branch information
hidefromkgb committed Jan 10, 2025
1 parent 3d833ff commit d27b7b7
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 71 deletions.
14 changes: 7 additions & 7 deletions src/common/memory_desc.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
* Copyright 2024-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -56,21 +56,21 @@ const rnn_packed_memory_format_t ldio_p = rnn_packed_memory_format_t::ldio_p;
// TODO: convert to 'enum class'.
// Flags for memory special features
enum memory_extra_flags_t {
dnnl_memory_extra_flag_none = 0x0U,
dnnl_memory_extra_flag_none = 0u,
// Indicates the weights have an additional buffer, that depends on the
// @p compensation_mask.
//
// For instance, in the 4D case with the compensation mask equal to (1 << 0),
// the additional buffer would consist of OC values:
// O[oc : 0,OC] =
// -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) }
dnnl_memory_extra_flag_compensation_conv_s8s8 = 0x1U,
dnnl_memory_extra_flag_scale_adjust = 0x2U,
dnnl_memory_extra_flag_rnn_u8s8_compensation = 0x4U,
dnnl_memory_extra_flag_compensation_conv_s8s8 = 1u,
dnnl_memory_extra_flag_scale_adjust = 2u,
dnnl_memory_extra_flag_rnn_u8s8_compensation = 4u,
dnnl_memory_extra_flag_gpu_rnn_u8s8_compensation
= dnnl_memory_extra_flag_rnn_u8s8_compensation,
dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 0x8U,
dnnl_memory_extra_flag_rnn_s8s8_compensation = 0x16U,
dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 8u,
dnnl_memory_extra_flag_rnn_s8s8_compensation = 16u,
};

// Create aliases for extra flags to preserve the old behavior.
Expand Down
28 changes: 10 additions & 18 deletions src/common/memory_desc_wrapper.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -149,9 +149,7 @@ struct memory_desc_wrapper : public c_compatible {
size_t additional_buffer_data_size(uint64_t flag_select) const {
using namespace memory_extra_flags;
if (flag_select & compensation_conv_s8s8) return sizeof(int32_t);
if ((flag_select & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(flag_select))
return sizeof(float);
if (flag_select & rnn_u8s8_compensation) return sizeof(float);
if (flag_select & compensation_conv_asymmetric_src)
return sizeof(int32_t);
return 0;
Expand All @@ -160,19 +158,16 @@ struct memory_desc_wrapper : public c_compatible {
/** return true if memory format has additional buffer */
bool is_additional_buffer() const {
using namespace memory_extra_flags;
// Currently compensation is not required for rnn_s8s8_compensation,
// but it has common bit with rnn_u8s8_compensation constant so we have
// to exclude rnn_s8s8_compensation case explicitly
return ((extra().flags
& (compensation_conv_s8s8 | rnn_u8s8_compensation
| compensation_conv_asymmetric_src))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
extra().flags));
return extra().flags
& (compensation_conv_s8s8 | rnn_u8s8_compensation
| compensation_conv_asymmetric_src);
}

/** returns the size required for a particular extra memory buffer */
size_t additional_buffer_size(memory_extra_flags_t flag) const {
using namespace memory_extra_flags;
const auto flags = extra().flags;
if (!(flags & flag)) return 0;

const auto ndims = this->ndims();
const auto &pdims = padded_dims();
Expand All @@ -186,18 +181,15 @@ struct memory_desc_wrapper : public c_compatible {
return (size_t)prod * buff_data_size;
};

if (extra().flags & compensation_conv_s8s8) {
if (flag == compensation_conv_s8s8) {
return calculate_size(extra().compensation_mask,
additional_buffer_data_size(flag));
}

if ((extra().flags & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
extra().flags)) {
if (flag == rnn_u8s8_compensation) {
return calculate_size(extra().compensation_mask,
additional_buffer_data_size(flag));
}
if (extra().flags & compensation_conv_asymmetric_src) {
if (flag == compensation_conv_asymmetric_src) {
return calculate_size(extra().asymm_compensation_mask,
additional_buffer_data_size(flag));
}
Expand Down
8 changes: 3 additions & 5 deletions src/common/primitive_hashing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,9 @@ size_t get_md_hash(const memory_desc_t &md) {

if (md.extra.flags != dnnl_memory_extra_flag_none) {
seed = hash_combine(seed, md.extra.flags);
if ((md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
md.extra.flags)) {
if (md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
seed = hash_combine(seed, md.extra.compensation_mask);
}

Expand Down
10 changes: 3 additions & 7 deletions src/common/serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,14 @@ void serialize_md(serialization_stream_t &sstream, const memory_desc_t &md) {

if (md.extra.flags != dnnl_memory_extra_flag_none) {
sstream.write(&md.extra.flags);
if ((md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
md.extra.flags)) {
if (md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
sstream.write(&md.extra.compensation_mask);
}

if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
sstream.write(&md.extra.scale_adjust);
}

if (md.extra.flags
& dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
sstream.write(&md.extra.asymm_compensation_mask);
Expand Down
19 changes: 3 additions & 16 deletions src/common/type_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,28 +299,15 @@ inline format_kind_t format_tag_to_kind(format_tag_t tag) {
return format_kind::undef;
}

// Currently rnn_s8s8_compensation has common bits with rnn_u8s8_compensation
// and scale_adjust constants so we have to perform additional checks to
// separate these two cases
inline bool extra_flag_rnn_s8s8_compensation_is_set(uint64_t flags) {
return ((flags & memory_extra_flags::rnn_s8s8_compensation)
^ memory_extra_flags::rnn_s8s8_compensation)
== 0;
}

inline bool memory_extra_desc_is_equal(
const memory_extra_desc_t &lhs, const memory_extra_desc_t &rhs) {
using namespace memory_extra_flags;
return true && lhs.flags == rhs.flags
return lhs.flags == rhs.flags
&& IMPLICATION(lhs.flags & compensation_conv_s8s8,
lhs.compensation_mask == rhs.compensation_mask)
&& IMPLICATION((lhs.flags & rnn_u8s8_compensation)
&& !extra_flag_rnn_s8s8_compensation_is_set(
lhs.flags),
&& IMPLICATION(lhs.flags & rnn_u8s8_compensation,
lhs.compensation_mask == rhs.compensation_mask)
&& IMPLICATION((lhs.flags & scale_adjust)
&& !extra_flag_rnn_s8s8_compensation_is_set(
lhs.flags),
&& IMPLICATION(lhs.flags & scale_adjust,
lhs.scale_adjust == rhs.scale_adjust)
&& IMPLICATION(lhs.flags & compensation_conv_asymmetric_src,
lhs.asymm_compensation_mask == rhs.asymm_compensation_mask);
Expand Down
11 changes: 2 additions & 9 deletions src/cpu/rnn/rnn_reorders.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -779,12 +779,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
return unimplemented;

// Check that the proper memory desc has been passed for u8s8 and s8s8
// Note: currently rnn_u8s8_compensation and rnn_s8s8_compensation
// have common bit so we have to perform additional checks to
// separate these two cases
const bool check_u8s8 = (od.extra().flags & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
od.extra().flags)
&& od.extra().compensation_mask
== ((id.ndims() == 5) ? 27 /* 11011 */
: 13 /* 1101 */);
Expand Down Expand Up @@ -886,9 +881,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
.template get<void>(memory_tracking::names::
key_reorder_rnn_weights_reduction);
float *comp = reinterpret_cast<float *>(dst + compensation_offset);
const bool req_s8s8_comp = (dst_d.extra().flags & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
dst_d.extra().flags);
const bool req_s8s8_comp = dst_d.extra().flags & rnn_u8s8_compensation;
const auto mask_ok = [&](int mask) {
return mask
== ((src_d.ndims() == 5) ? 27 /* 11011 */
Expand Down
12 changes: 3 additions & 9 deletions src/gpu/intel/ocl/rnn/rnn_reorders.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -42,14 +42,8 @@ struct rnn_weights_reorder_t : public gpu_primitive_t {

status_t init(impl::engine_t *engine, impl::engine_t *src_engine,
impl::engine_t *dst_engine) {
// Note: currently rnn_u8s8_compensation and rnn_s8s8_compensation
// have common bit so we have to perform additional checks to
// separate these two cases
VDISPATCH_REORDER(
!IMPLICATION(dst_md()->extra.flags
& memory_extra_flags::rnn_u8s8_compensation,
types::extra_flag_rnn_s8s8_compensation_is_set(
dst_md()->extra.flags)),
VDISPATCH_REORDER(dst_md()->extra.flags
& memory_extra_flags::rnn_u8s8_compensation,
VERBOSE_BAD_FLAGS);

VDISPATCH_REORDER(utils::one_of(src_engine->kind(),
Expand Down

0 comments on commit d27b7b7

Please sign in to comment.