From 69a111e6d835f8632ea571f3ea0e273b22488d37 Mon Sep 17 00:00:00 2001 From: Dmitrii Zarukin Date: Thu, 29 Feb 2024 12:07:14 -0800 Subject: [PATCH] cpu: x64: jit_reorder: restrict usage of prime numbers exceeding INT_MAX --- src/cpu/x64/jit_uni_reorder.cpp | 28 +++++++++++++++++++++++---- src/cpu/x64/jit_uni_reorder_utils.cpp | 4 ++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/cpu/x64/jit_uni_reorder.cpp b/src/cpu/x64/jit_uni_reorder.cpp index cd32a5bc407..7ac64c62992 100644 --- a/src/cpu/x64/jit_uni_reorder.cpp +++ b/src/cpu/x64/jit_uni_reorder.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -91,6 +91,14 @@ static bool prb_has_small_strides(const prb_t &prb) { return true; } +/* True when some node's n is a prime number >= INT_MAX. */ static bool prb_has_huge_prime_number(const prb_t &prb) { + for (int d = 0; d < prb.ndims; ++d) { + auto n = prb.nodes[d].n; + if (n >= INT_MAX && math::is_prime(n)) return true; + } + return false; +} + /** Minimal reasonable/desirable kernel size. * The constant might be used to determine how a problem should be split * between kernel and threading driver.
*/ @@ -182,7 +190,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { static bool applicable(const prb_t &p) { using namespace data_type; - bool ok = true && p.ndims > 0 + bool ok = p.ndims > 0 && utils::one_of(p.itype, f32, bf16, f16, s32, s8, u8) && utils::one_of(p.otype, f32, bf16, f16, s32, s8, u8) && IMPLICATION(utils::one_of(p.itype, bf16, f16), @@ -196,7 +204,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { mayiuse(avx512_core) || mayiuse(avx2_vnni_2)) && IMPLICATION(utils::one_of(f16, p.itype, p.otype), mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2)) - && IMPLICATION(!is_direct_copy(p), prb_has_small_strides(p)); + && IMPLICATION(!is_direct_copy(p), prb_has_small_strides(p)) + && !prb_has_huge_prime_number(p); return ok; } @@ -500,7 +509,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { // TODO: make a standalone jit:direct_copy implementation. const bool can_do = is_direct_copy(prb_) // s8u8 with AVX should be used with XMM vreg. - && IMPLICATION(is_i8 && isa_ == avx, !is_ymm); + && IMPLICATION(is_i8 && isa_ == avx, !is_ymm) + // Reject prime numbers >= INT_MAX: huge primes cause input + // address overflow and crash. + && !prb_has_huge_prime_number(prb_); if (!can_do) return false; const int tail_opmask_idx = 2; @@ -2108,6 +2120,14 @@ static void prb_thread_kernel_balance( * size_drv_cur.
*/ const bool want_borrow_drv_from_ker = size_ker_cur > tr::ker_prb_size_min && size_drv_cur < size_drv_min; + + VDEBUGINFO(5, primitive, reorder, + "size_drv_thr=%zu size_drv_min=%zu size_drv_cur=%zu " + "tr::ker_prb_size_min=%zu want_borrow_ker_from_drv=%d " + "want_borrow_drv_from_ker=%d", + size_drv_thr, size_drv_min, size_drv_cur, tr::ker_prb_size_min, + want_borrow_ker_from_drv, want_borrow_drv_from_ker); + if (want_borrow_drv_from_ker) { size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur); for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow) diff --git a/src/cpu/x64/jit_uni_reorder_utils.cpp b/src/cpu/x64/jit_uni_reorder_utils.cpp index 9974d69c213..f14805012f4 100644 --- a/src/cpu/x64/jit_uni_reorder_utils.cpp +++ b/src/cpu/x64/jit_uni_reorder_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -95,7 +95,7 @@ status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, const int ld_ndims_start = ld.ndims; if (blocks[d] != 1) { stride_t stride = 1; - int tail = tails[d]; + dim_t tail = tails[d]; for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { if (bd.inner_idxs[iblk] == d) { const dim_t inner_tail = tail % bd.inner_blks[iblk];