Skip to content

Commit

Permalink
x64: ip: add avx2_vnni_2 xf16 tags
Browse files (browse the repository at this point in the history)
  • Loading branch information
nivas-x86 authored and tprimak committed Apr 5, 2023
1 parent f5654f5 commit f8d7c2e
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 14 deletions.
16 changes: 16 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2427,6 +2427,22 @@ struct memory : public handle<dnnl_memory_t> {
ABcde4b24a4b = dnnl_ABcde4b24a4b,
OhwI24o = dnnl_OhwI24o,
gOhwI24o = dnnl_gOhwI24o,
AB8b24a2b = dnnl_AB8b24a2b,
ABc8b24a2b = dnnl_ABc8b24a2b,
ABcd8b24a2b = dnnl_ABcd8b24a2b,
ABcde8b24a2b = dnnl_ABcde8b24a2b,
AB8b8a2b = dnnl_AB8b8a2b,
ABc8b8a2b = dnnl_ABc8b8a2b,
ABcd8b8a2b = dnnl_ABcd8b8a2b,
ABcde8b8a2b = dnnl_ABcde8b8a2b,
OI8i8o2i = dnnl_OI8i8o2i,
OI8i24o2i = dnnl_OI8i24o2i,
OIw8i8o2i = dnnl_OIw8i8o2i,
OIw8i24o2i = dnnl_OIw8i24o2i,
OIhw8i8o2i = dnnl_OIhw8i8o2i,
OIhw8i24o2i = dnnl_OIhw8i24o2i,
OIdhw8i8o2i = dnnl_OIdhw8i8o2i,
OIdhw8i24o2i = dnnl_OIdhw8i24o2i,
};

/// A memory descriptor.
Expand Down
16 changes: 16 additions & 0 deletions include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,14 @@ typedef enum {
dnnl_ABcd4b24a4b,
dnnl_ABcde4b8a4b,
dnnl_ABcde4b24a4b,
dnnl_AB8b24a2b,
dnnl_ABc8b24a2b,
dnnl_ABcd8b24a2b,
dnnl_ABcde8b24a2b,
dnnl_AB8b8a2b,
dnnl_ABc8b8a2b,
dnnl_ABcd8b8a2b,
dnnl_ABcde8b8a2b,

/// Just a sentinel, not real memory format tag. Must be changed after new
/// format tag is added.
Expand Down Expand Up @@ -1084,7 +1092,9 @@ typedef enum {
dnnl_OI16i16o = dnnl_AB16b16a,
dnnl_OI16i32o = dnnl_AB16b32a,
dnnl_OI16i64o = dnnl_AB16b64a,
dnnl_OI8i8o2i = dnnl_AB8b8a2b,
dnnl_OI8i16o2i = dnnl_AB8b16a2b,
dnnl_OI8i24o2i = dnnl_AB8b24a2b,
dnnl_OI8i32o2i = dnnl_AB8b32a2b,
dnnl_OI8i64o2i = dnnl_AB8b64a2b,
dnnl_OI4i8o4i = dnnl_AB4b8a4b,
Expand Down Expand Up @@ -1118,7 +1128,9 @@ typedef enum {
dnnl_OIw4i4o = dnnl_ABc4b4a,
dnnl_OIw4o4i = dnnl_ABc4a4b,
dnnl_Oiw4o = dnnl_Abc4a,
dnnl_OIw8i8o2i = dnnl_ABc8b8a2b,
dnnl_OIw8i16o2i = dnnl_ABc8b16a2b,
dnnl_OIw8i24o2i = dnnl_ABc8b24a2b,
dnnl_OIw8i32o2i = dnnl_ABc8b32a2b,
dnnl_OIw8i64o2i = dnnl_ABc8b64a2b,
dnnl_OIw8i8o = dnnl_ABc8b8a,
Expand Down Expand Up @@ -1169,8 +1181,10 @@ typedef enum {
dnnl_OIhw4i4o = dnnl_ABcd4b4a,
dnnl_OIhw4o4i = dnnl_ABcd4a4b,
dnnl_Oihw4o = dnnl_Abcd4a,
dnnl_OIhw8i8o2i = dnnl_ABcd8b8a2b,
dnnl_OIhw8i16o2i = dnnl_ABcd8b16a2b,
dnnl_OIhw8i32o2i = dnnl_ABcd8b32a2b,
dnnl_OIhw8i24o2i = dnnl_ABcd8b24a2b,
dnnl_OIhw8i64o2i = dnnl_ABcd8b64a2b,
dnnl_OIhw8i8o = dnnl_ABcd8b8a,
dnnl_OIhw8o16i2o = dnnl_ABcd8a16b2a,
Expand Down Expand Up @@ -1202,8 +1216,10 @@ typedef enum {
dnnl_OIdhw4i4o = dnnl_ABcde4b4a,
dnnl_OIdhw4o4i = dnnl_ABcde4a4b,
dnnl_Oidhw4o = dnnl_Abcde4a,
dnnl_OIdhw8i8o2i = dnnl_ABcde8b8a2b,
dnnl_OIdhw8i16o2i = dnnl_ABcde8b16a2b,
dnnl_OIdhw8i32o2i = dnnl_ABcde8b32a2b,
dnnl_OIdhw8i24o2i = dnnl_ABcde8b24a2b,
dnnl_OIdhw8i64o2i = dnnl_ABcde8b64a2b,
dnnl_OIdhw8i8o = dnnl_ABcde8b8a,
dnnl_OIdhw8o16i2o = dnnl_ABcde8a16b2a,
Expand Down
16 changes: 16 additions & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,14 @@ const format_tag_t ABc8b16a = dnnl_ABc8b16a;
const format_tag_t ABcd8b16a = dnnl_ABcd8b16a;
const format_tag_t ABcde8b16a = dnnl_ABcde8b16a;
const format_tag_t AB8b8a = dnnl_AB8b8a;
const format_tag_t AB8b8a2b = dnnl_AB8b8a2b;
const format_tag_t ABc8b8a2b = dnnl_ABc8b8a2b;
const format_tag_t ABcd8b8a2b = dnnl_ABcd8b8a2b;
const format_tag_t ABcde8b8a2b = dnnl_ABcde8b8a2b;
const format_tag_t AB8b24a2b = dnnl_AB8b24a2b;
const format_tag_t ABc8b24a2b = dnnl_ABc8b24a2b;
const format_tag_t ABcd8b24a2b = dnnl_ABcd8b24a2b;
const format_tag_t ABcde8b24a2b = dnnl_ABcde8b24a2b;

const format_tag_t last = dnnl_format_tag_last;

Expand Down Expand Up @@ -1573,6 +1581,14 @@ const format_tag_t OIw8i16o = dnnl_OIw8i16o;
const format_tag_t OIhw8i16o = dnnl_OIhw8i16o;
const format_tag_t OIdhw8i16o = dnnl_OIdhw8i16o;
const format_tag_t OI8i8o = dnnl_OI8i8o;
const format_tag_t OI8i8o2i = dnnl_OI8i8o2i;
const format_tag_t OIw8i8o2i = dnnl_OIw8i8o2i;
const format_tag_t OIhw8i8o2i = dnnl_OIhw8i8o2i;
const format_tag_t OIdhw8i8o2i = dnnl_OIdhw8i8o2i;
const format_tag_t OI8i24o2i = dnnl_OI8i24o2i;
const format_tag_t OIw8i24o2i = dnnl_OIw8i24o2i;
const format_tag_t OIhw8i24o2i = dnnl_OIhw8i24o2i;
const format_tag_t OIdhw8i24o2i = dnnl_OIdhw8i24o2i;
} // namespace format_tag

using normalization_flags_t = dnnl_normalization_flags_t;
Expand Down
16 changes: 16 additions & 0 deletions src/common/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,14 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_ABcd4b24a4b) return "ABcd4b24a4b";
if (v == dnnl_ABcde4b8a4b) return "ABcde4b8a4b";
if (v == dnnl_ABcde4b24a4b) return "ABcde4b24a4b";
if (v == dnnl_AB8b24a2b) return "AB8b24a2b";
if (v == dnnl_ABc8b24a2b) return "ABc8b24a2b";
if (v == dnnl_ABcd8b24a2b) return "ABcd8b24a2b";
if (v == dnnl_ABcde8b24a2b) return "ABcde8b24a2b";
if (v == dnnl_AB8b8a2b) return "AB8b8a2b";
if (v == dnnl_ABc8b8a2b) return "ABc8b8a2b";
if (v == dnnl_ABcd8b8a2b) return "ABcd8b8a2b";
if (v == dnnl_ABcde8b8a2b) return "ABcde8b8a2b";
if (v == dnnl_format_tag_last) return "format_tag_last";
if (v == dnnl_x) return "x";
if (v == dnnl_nc) return "nc";
Expand Down Expand Up @@ -873,7 +881,9 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_OI16i16o) return "OI16i16o";
if (v == dnnl_OI16i32o) return "OI16i32o";
if (v == dnnl_OI16i64o) return "OI16i64o";
if (v == dnnl_OI8i8o2i) return "OI8i8o2i";
if (v == dnnl_OI8i16o2i) return "OI8i16o2i";
if (v == dnnl_OI8i24o2i) return "OI8i24o2i";
if (v == dnnl_OI8i32o2i) return "OI8i32o2i";
if (v == dnnl_OI8i64o2i) return "OI8i64o2i";
if (v == dnnl_OI4i8o4i) return "OI4i8o4i";
Expand Down Expand Up @@ -905,7 +915,9 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_OIw4i4o) return "OIw4i4o";
if (v == dnnl_OIw4o4i) return "OIw4o4i";
if (v == dnnl_Oiw4o) return "Oiw4o";
if (v == dnnl_OIw8i8o2i) return "OIw8i8o2i";
if (v == dnnl_OIw8i16o2i) return "OIw8i16o2i";
if (v == dnnl_OIw8i24o2i) return "OIw8i24o2i";
if (v == dnnl_OIw8i32o2i) return "OIw8i32o2i";
if (v == dnnl_OIw8i64o2i) return "OIw8i64o2i";
if (v == dnnl_OIw8i8o) return "OIw8i8o";
Expand Down Expand Up @@ -954,8 +966,10 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_OIhw4i4o) return "OIhw4i4o";
if (v == dnnl_OIhw4o4i) return "OIhw4o4i";
if (v == dnnl_Oihw4o) return "Oihw4o";
if (v == dnnl_OIhw8i8o2i) return "OIhw8i8o2i";
if (v == dnnl_OIhw8i16o2i) return "OIhw8i16o2i";
if (v == dnnl_OIhw8i32o2i) return "OIhw8i32o2i";
if (v == dnnl_OIhw8i24o2i) return "OIhw8i24o2i";
if (v == dnnl_OIhw8i64o2i) return "OIhw8i64o2i";
if (v == dnnl_OIhw8i8o) return "OIhw8i8o";
if (v == dnnl_OIhw8o16i2o) return "OIhw8o16i2o";
Expand Down Expand Up @@ -985,8 +999,10 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_OIdhw4i4o) return "OIdhw4i4o";
if (v == dnnl_OIdhw4o4i) return "OIdhw4o4i";
if (v == dnnl_Oidhw4o) return "Oidhw4o";
if (v == dnnl_OIdhw8i8o2i) return "OIdhw8i8o2i";
if (v == dnnl_OIdhw8i16o2i) return "OIdhw8i16o2i";
if (v == dnnl_OIdhw8i32o2i) return "OIdhw8i32o2i";
if (v == dnnl_OIdhw8i24o2i) return "OIdhw8i24o2i";
if (v == dnnl_OIdhw8i64o2i) return "OIdhw8i64o2i";
if (v == dnnl_OIdhw8i8o) return "OIdhw8i8o";
if (v == dnnl_OIdhw8o16i2o) return "OIdhw8o16i2o";
Expand Down
9 changes: 9 additions & 0 deletions src/common/memory_desc_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,15 @@ status_t memory_desc_wrapper::compute_blocking(
C(ABcd8b16a, {0, 1, 2, 3}, {8, 16}, {1, 0});
C(ABcde8b16a, {0, 1, 2, 3, 4}, {8, 16}, {1, 0});
C(AB8b8a, {0, 1}, {8, 8}, {1, 0});

C(AB8b24a2b, {0, 1}, {8, 24, 2}, {1, 0, 1});
C(ABc8b24a2b, {0, 1, 2}, {8, 24, 2}, {1, 0, 1});
C(ABcd8b24a2b, {0, 1, 2, 3}, {8, 24, 2}, {1, 0, 1});
C(ABcde8b24a2b, {0, 1, 2, 3, 4}, {8, 24, 2}, {1, 0, 1});
C(AB8b8a2b, {0, 1}, {8, 8, 2}, {1, 0, 1});
C(ABc8b8a2b, {0, 1, 2}, {8, 8, 2}, {1, 0, 1});
C(ABcd8b8a2b, {0, 1, 2, 3}, {8, 8, 2}, {1, 0, 1});
C(ABcde8b8a2b, {0, 1, 2, 3, 4}, {8, 8, 2}, {1, 0, 1});
default: break;
}

Expand Down
40 changes: 27 additions & 13 deletions src/common/tag_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ enum class inner_blk_t {
_4b32a4b,
_4b64a4b,
_2b8a4b,
_8b8a2b,
_8b16a2b,
_8b24a2b,
_8b32a2b,
_8b64a2b,
_8b16c2b,
Expand Down Expand Up @@ -190,19 +192,20 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) {
ib::_16b2c, ib::_16b4c, ib::_2c8b4c, ib::_8a16b2a,
ib::_4b64a4b, ib::_4b32a4b, ib::_4b24a4b, ib::_4b16a4b,
ib::_4b8a4b, ib::_2b8a4b, ib::_8b64a2b, ib::_8b32a2b,
ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c,
ib::_2b4c2b, ib::_2c4b2c, ib::_4b8c2b, ib::_4c8b2c,
ib::_16a32b, ib::_16a48b, ib::_16a64b, ib::_16a16b2a,
ib::_16a32b2a, ib::_16a48b2a, ib::_16a64b2a, ib::_16a16b4a,
ib::_16a32b4a, ib::_16a48b4a, ib::_16a64b4a, ib::_16b16a2b,
ib::_16b16a4b, ib::_16b16c2b, ib::_16c16b2c, ib::_16c16b4c,
ib::_2a8b8a2b, ib::_2b8c8b2c, ib::_4a8b8a4b, ib::_4b8c8b4c,
ib::_16b32a2b, ib::_16b48a2b, ib::_16b64a2b, ib::_16b32a4b,
ib::_16b48a4b, ib::_16b64a4b, ib::_16c32b2c, ib::_16c48b2c,
ib::_16c64b2c, ib::_16c32b4c, ib::_16c48b4c, ib::_16c64b4c,
ib::_16b32c, ib::_16b48c, ib::_16b64c, ib::_16b32c2b,
ib::_16b48c2b, ib::_16b64c2b, ib::_16b16c4b, ib::_16b32c4b,
ib::_16b48c4b, ib::_16b64c4b, ib::_24a2b, ib::_24b2c),
ib::_8b24a2b, ib::_8b16a2b, ib::_8b8a2b, ib::_8b16c2b,
ib::_4c16b4c, ib::_8c16b2c, ib::_2b4c2b, ib::_2c4b2c,
ib::_4b8c2b, ib::_4c8b2c, ib::_16a32b, ib::_16a48b,
ib::_16a64b, ib::_16a16b2a, ib::_16a32b2a, ib::_16a48b2a,
ib::_16a64b2a, ib::_16a16b4a, ib::_16a32b4a, ib::_16a48b4a,
ib::_16a64b4a, ib::_16b16a2b, ib::_16b16a4b, ib::_16b16c2b,
ib::_16c16b2c, ib::_16c16b4c, ib::_2a8b8a2b, ib::_2b8c8b2c,
ib::_4a8b8a4b, ib::_4b8c8b4c, ib::_16b32a2b, ib::_16b48a2b,
ib::_16b64a2b, ib::_16b32a4b, ib::_16b48a4b, ib::_16b64a4b,
ib::_16c32b2c, ib::_16c48b2c, ib::_16c64b2c, ib::_16c32b4c,
ib::_16c48b4c, ib::_16c64b4c, ib::_16b32c, ib::_16b48c,
ib::_16b64c, ib::_16b32c2b, ib::_16b48c2b, ib::_16b64c2b,
ib::_16b16c4b, ib::_16b32c4b, ib::_16b48c4b, ib::_16b64c4b,
ib::_24a2b, ib::_24b2c),
"unexpected inner_blk format");

// clang-format off
Expand Down Expand Up @@ -237,7 +240,9 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) {
: (f == ib::_2b8a4b || f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4
: (f == ib::_16b16a2b || f == ib::_16c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
: (f == ib::_16b16a4b || f == ib::_16c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4
: (f == ib::_8b8a2b) ? (x1 / 2) * 16 + x0 * 2 + x1 % 2
: (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
: (f == ib::_8b24a2b) ? (x1 / 2) * 48 + x0 * 2 + x1 % 2
: (f == ib::_8b32a2b) ? (x1 / 2) * 64 + x0 * 2 + x1 % 2
: (f == ib::_8b64a2b) ? (x1 / 2) * 128 + x0 * 2 + x1 % 2
: (f == ib::_2b4c2b || f == ib::_2c4b2c) ? (x0 / 2) * 8 + x1 * 2 + x0 % 2
Expand Down Expand Up @@ -790,6 +795,15 @@ DECL_TRAITS(ABc8b16a, _AB, _8b16a, 3);
DECL_TRAITS(ABcd8b16a, _AB, _8b16a, 4);
DECL_TRAITS(ABcde8b16a, _AB, _8b16a, 5);
DECL_TRAITS(AB8b8a, _AB, _8b8a, 2);

DECL_TRAITS(AB8b8a2b, _AB, _8b8a2b, 2);
DECL_TRAITS(ABc8b8a2b, _AB, _8b8a2b, 3);
DECL_TRAITS(ABcd8b8a2b, _AB, _8b8a2b, 4);
DECL_TRAITS(ABcde8b8a2b, _AB, _8b8a2b, 5);
DECL_TRAITS(AB8b24a2b, _AB, _8b24a2b, 2);
DECL_TRAITS(ABc8b24a2b, _AB, _8b24a2b, 3);
DECL_TRAITS(ABcd8b24a2b, _AB, _8b24a2b, 4);
DECL_TRAITS(ABcde8b24a2b, _AB, _8b24a2b, 5);
} // namespace impl
} // namespace dnnl

Expand Down
8 changes: 7 additions & 1 deletion src/cpu/x64/jit_brgemm_inner_product_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,15 @@ std::unordered_map<int, format_tag_t> get_desired_weights_tag(
{32,
pick(n_sp_dims, OI8i32o2i, OIw8i32o2i, OIhw8i32o2i,
OIdhw8i32o2i)},
{24,
pick(n_sp_dims, OI8i24o2i, OIw8i24o2i, OIhw8i24o2i,
OIdhw8i24o2i)},
{16,
pick(n_sp_dims, OI8i16o2i, OIw8i16o2i, OIhw8i16o2i,
OIdhw8i16o2i)}};
OIdhw8i16o2i)},
{8,
pick(n_sp_dims, OI8i8o2i, OIw8i8o2i, OIhw8i8o2i,
OIdhw8i8o2i)}};
}
} else if (jbgp.wei_dt == data_type::s8) {
if (jbgp.is_amx) {
Expand Down
16 changes: 16 additions & 0 deletions tests/benchdnn/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,14 @@ dnnl_format_tag_t str2fmt_tag(const char *str) {
CASE(ABcd4b24a4b);
CASE(ABcde4b8a4b);
CASE(ABcde4b24a4b);
CASE(AB8b24a2b);
CASE(ABc8b24a2b);
CASE(ABcd8b24a2b);
CASE(ABcde8b24a2b);
CASE(AB8b8a2b);
CASE(ABc8b8a2b);
CASE(ABcd8b8a2b);
CASE(ABcde8b8a2b);
CASE(x);
CASE(nc);
CASE(cn);
Expand Down Expand Up @@ -859,7 +867,9 @@ dnnl_format_tag_t str2fmt_tag(const char *str) {
CASE(OI16i16o);
CASE(OI16i32o);
CASE(OI16i64o);
CASE(OI8i8o2i);
CASE(OI8i16o2i);
CASE(OI8i24o2i);
CASE(OI8i32o2i);
CASE(OI8i64o2i);
CASE(OI4i8o4i);
Expand Down Expand Up @@ -891,7 +901,9 @@ dnnl_format_tag_t str2fmt_tag(const char *str) {
CASE(OIw4i4o);
CASE(OIw4o4i);
CASE(Oiw4o);
CASE(OIw8i8o2i);
CASE(OIw8i16o2i);
CASE(OIw8i24o2i);
CASE(OIw8i32o2i);
CASE(OIw8i64o2i);
CASE(OIw8i8o);
Expand Down Expand Up @@ -940,8 +952,10 @@ dnnl_format_tag_t str2fmt_tag(const char *str) {
CASE(OIhw4i4o);
CASE(OIhw4o4i);
CASE(Oihw4o);
CASE(OIhw8i8o2i);
CASE(OIhw8i16o2i);
CASE(OIhw8i32o2i);
CASE(OIhw8i24o2i);
CASE(OIhw8i64o2i);
CASE(OIhw8i8o);
CASE(OIhw8o16i2o);
Expand Down Expand Up @@ -971,8 +985,10 @@ dnnl_format_tag_t str2fmt_tag(const char *str) {
CASE(OIdhw4i4o);
CASE(OIdhw4o4i);
CASE(Oidhw4o);
CASE(OIdhw8i8o2i);
CASE(OIdhw8i16o2i);
CASE(OIdhw8i32o2i);
CASE(OIdhw8i24o2i);
CASE(OIdhw8i64o2i);
CASE(OIdhw8i8o);
CASE(OIdhw8o16i2o);
Expand Down

0 comments on commit f8d7c2e

Please sign in to comment.