diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
index 429c9271cc8..46f6f4d6fd4 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -47,18 +47,6 @@
 using namespace dnnl::impl::utils;
 using namespace data_type;
 using namespace format_tag;
 
-// Condition generated by decision tree classifier.
-bool amx_xf16_is_small_shape(dim_t batch, dim_t M, dim_t K, dim_t N) {
-    const float b = static_cast<float>(batch);
-    const float m = static_cast<float>(M);
-    const float k = static_cast<float>(K);
-    const float n = static_cast<float>(N);
-    return (k <= 28 && n / k > 39.7) || (k <= 28 && n / k <= 8)
-            || (k > 28 && m * k <= 248.0 && n > 52.0)
-            || ((m * k * n) / b > 4862 && b * k >= 23 && n / k > 8
-                    && m * k * n <= 60817408);
-}
-
 int get_default_n_block(format_tag_t matrix_b_tag) {
     // Note: consider using weights mem_descriptor 'inner_blks' to
     // return B's inner block for non-default cases.
@@ -1045,29 +1033,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     bgmmc.bcast_B_desc.set_params(
             weights_d.dims(), dst_d.dims(), bgmmc.batch_ndims, bgmmc.batch);
 
-    // Dispatch small shapes to VNNI for better performance
-    const bool runtime_dims
-            = bgmmc.is_runtime_M || bgmmc.is_runtime_N || bgmmc.is_runtime_K;
-
-    bool is_small_shapes = bgmmc.is_amx && !runtime_dims;
-
-    // Disable 'small_shape' heuristic for amx_fp16 until it is validated with
-    // performance measurements.
-    is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);
-
-    if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()) {
-        is_small_shapes = is_small_shapes
-                && amx_xf16_is_small_shape(
-                        bgmmc.batch, bgmmc.M, bgmmc.K, bgmmc.N);
-    } else {
-        is_small_shapes = is_small_shapes && bgmmc.ndims < 3
-                && ((bgmmc.M == 1 && bgmmc.K == 256)
-                        || (bgmmc.M <= 32 && bgmmc.M * bgmmc.N <= 256)
-                        || bgmmc.K <= 16);
-    }
-
-    VCONDCHECK_BG(!is_small_shapes, VERBOSE_SMALL_SHAPES);
-
     // required granularity for k dimension
     bgmmc.required_k_granularity
             = bgmmc.is_amx ? data_type_vnni_granularity(bgmmc.wei_dt) : 1;
@@ -1214,6 +1179,29 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
 
     init_aux_values(bgmmc, src_d, weights_d, dst_d);
 
+    // Dispatch small shapes to VNNI for better performance
+    const bool runtime_dims
+            = bgmmc.is_runtime_M || bgmmc.is_runtime_N || bgmmc.is_runtime_K;
+
+    bool is_small_shapes = bgmmc.is_amx && !runtime_dims;
+
+    // Disable 'small_shape' heuristic for amx_fp16 until it is validated with
+    // performance measurements.
+    is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);
+
+    if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()) {
+        // Empirical performance breakpoint between AMX and VNNI bf16/f16.
+        const dim_t buffer_a_chunk_sz_limit = 126;
+        is_small_shapes = is_small_shapes
+                && bgmmc.buffer_a_chunk_sz <= buffer_a_chunk_sz_limit;
+    } else {
+        is_small_shapes = is_small_shapes && bgmmc.ndims < 3
+                && ((bgmmc.M == 1 && bgmmc.K == 256)
+                        || (bgmmc.M <= 32 && bgmmc.M * bgmmc.N <= 256)
+                        || bgmmc.K <= 16);
+    }
+    VCONDCHECK_BG(!is_small_shapes, VERBOSE_SMALL_SHAPES);
+
     return status::success;
 }
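
Note for reviewers: the predicate introduced above can be read as the standalone sketch below. This is a hedged restatement, not code from the patch; the function name is_dispatched_to_vnni and its flat parameter list are hypothetical stand-ins for the bgmmc and bm_conf_utils fields that the real check reads inside init_brgemm_matmul_conf(), and the units of buffer_a_chunk_sz are whatever init_aux_values() computes.

    #include <cstdint>

    using dim_t = std::int64_t;

    // Hypothetical standalone restatement of the patched heuristic; in the
    // real code a true result makes VCONDCHECK_BG reject the AMX kernel so
    // the shape falls back to the VNNI implementation.
    bool is_dispatched_to_vnni(bool is_amx, bool has_runtime_dims,
            bool is_amx_fp16, bool is_xf16, dim_t buffer_a_chunk_sz, int ndims,
            dim_t M, dim_t N, dim_t K) {
        // Only statically-shaped AMX problems are candidates for re-dispatch.
        if (!is_amx || has_runtime_dims) return false;
        // amx_fp16 stays on the AMX path until the heuristic is validated.
        if (is_amx_fp16) return false;
        if (is_xf16) {
            // Empirical breakpoint from the patch: an A-buffer chunk at or
            // below 126 is too small to pay off on the AMX path.
            const dim_t buffer_a_chunk_sz_limit = 126;
            return buffer_a_chunk_sz <= buffer_a_chunk_sz_limit;
        }
        // Other data types keep the original shape-based conditions.
        return ndims < 3
                && ((M == 1 && K == 256) || (M <= 32 && M * N <= 256)
                        || K <= 16);
    }

The move of the block from before the blocking setup to after init_aux_values() appears to be forced by the new condition itself: bgmmc.buffer_a_chunk_sz is an auxiliary value, so it is presumably only populated once init_aux_values() has run, replacing the shape-only decision-tree classifier with a single size threshold that is available at that point.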