Skip to content

Commit

Permalink
cpu: x64: jit_reorder: restrict usage of prime numbers exceeding INT_MAX
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin authored and karturov committed Mar 27, 2024
1 parent 4b72361 commit 69a111e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
28 changes: 24 additions & 4 deletions src/cpu/x64/jit_uni_reorder.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -91,6 +91,14 @@ static bool prb_has_small_strides(const prb_t &prb) {
return true;
}

/** Checks whether the problem contains a node whose size `n` is both at least
 * INT_MAX and prime.
 *
 * Walks the first `prb.ndims` entries of `prb.nodes` and stops at the first
 * match.
 *
 * @param prb Reorder problem descriptor to inspect.
 * @return true if such a node exists, false otherwise (including when
 *     `prb.ndims` is not positive). */
bool prb_has_huge_prime_number(const prb_t &prb) {
    for (int idx = 0; idx < prb.ndims; ++idx) {
        auto node_size = prb.nodes[idx].n;
        // Sizes below the threshold are never reported; skipping them first
        // also avoids the primality test for them.
        if (node_size < INT_MAX) continue;
        if (math::is_prime(node_size)) return true;
    }
    return false;
}

/** Minimal reasonable/desirable kernel size.
* The constant might be used to determine how a problem should be split
* between kernel and threading driver. */
Expand Down Expand Up @@ -182,7 +190,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
static bool applicable(const prb_t &p) {
using namespace data_type;

bool ok = true && p.ndims > 0
bool ok = p.ndims > 0
&& utils::one_of(p.itype, f32, bf16, f16, s32, s8, u8)
&& utils::one_of(p.otype, f32, bf16, f16, s32, s8, u8)
&& IMPLICATION(utils::one_of(p.itype, bf16, f16),
Expand All @@ -196,7 +204,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
mayiuse(avx512_core) || mayiuse(avx2_vnni_2))
&& IMPLICATION(utils::one_of(f16, p.itype, p.otype),
mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2))
&& IMPLICATION(!is_direct_copy(p), prb_has_small_strides(p));
&& IMPLICATION(!is_direct_copy(p), prb_has_small_strides(p))
&& !prb_has_huge_prime_number(p);
return ok;
}

Expand Down Expand Up @@ -500,7 +509,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
// TODO: make a standalone jit:direct_copy implementation.
const bool can_do = is_direct_copy(prb_)
// s8u8 with AVX should be used with XMM vreg.
&& IMPLICATION(is_i8 && isa_ == avx, !is_ymm);
&& IMPLICATION(is_i8 && isa_ == avx, !is_ymm)
// Prime numbers greater than INT_MAX cause input address
// overflow and crash.
&& !prb_has_huge_prime_number(prb_);
if (!can_do) return false;

const int tail_opmask_idx = 2;
Expand Down Expand Up @@ -2108,6 +2120,14 @@ static void prb_thread_kernel_balance(
* size_drv_cur. */
const bool want_borrow_drv_from_ker = size_ker_cur > tr::ker_prb_size_min
&& size_drv_cur < size_drv_min;

VDEBUGINFO(5, primitive, reorder,
"size_drv_thr=%zu size_drv_min=%zu size_drv_cur=%zu "
"tr::ker_prb_size_min=%zu want_borrow_ker_from_drv=%d "
"want_borrow_drv_from_ker=%d",
size_drv_thr, size_drv_min, size_drv_cur, tr::ker_prb_size_min,
want_borrow_ker_from_drv, want_borrow_drv_from_ker);

if (want_borrow_drv_from_ker) {
size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur);
for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow)
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_uni_reorder_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -95,7 +95,7 @@ status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
const int ld_ndims_start = ld.ndims;
if (blocks[d] != 1) {
stride_t stride = 1;
int tail = tails[d];
dim_t tail = tails[d];
for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
if (bd.inner_idxs[iblk] == d) {
const dim_t inner_tail = tail % bd.inner_blks[iblk];
Expand Down

0 comments on commit 69a111e

Please sign in to comment.