Skip to content

Commit

Permalink
src: cpu: conv: Use acl_indirect_gemm for bf16 convolutions
Browse files Browse the repository at this point in the history
Performance improvements (speedup factors measured with benchdnn):

Total benchdnn tests: 57
Min: 15x
Average: 131x
Max: 320x
  • Loading branch information
Ryo-not-rio authored and vpirogov committed Jun 24, 2024
1 parent 094cc1d commit 3a05ca5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
6 changes: 4 additions & 2 deletions src/cpu/aarch64/acl_convolution_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*******************************************************************************/

#include "cpu/aarch64/acl_convolution_utils.hpp"
#include "common/convolution_pd.hpp"
#include "common/utils.hpp"
#include "oneapi/dnnl/dnnl.h"

Expand Down Expand Up @@ -62,9 +63,10 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md,
everyone_is(data_type::f32, src_d.data_type(),
wei_d.data_type(), dst_d.data_type()),
everyone_is(data_type::f16, src_d.data_type(),
wei_d.data_type(), dst_d.data_type()),
everyone_is(data_type::bf16, src_d.data_type(),
wei_d.data_type(), dst_d.data_type())),
" src, dst and wei must be fp16 or fp32");

" src, dst and wei must be fp16, bf16 or fp32");
// batch size
const int mb = src_d.dims()[0];

Expand Down
7 changes: 5 additions & 2 deletions src/cpu/aarch64/acl_indirect_gemm_convolution.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2023 Arm Ltd. and affiliates
* Copyright 2021-2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -84,12 +84,15 @@ struct acl_indirect_gemm_convolution_fwd_t : public primitive_t {

const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef)
&& attr()->has_default_values(smask_t::post_ops, f16);
const bool is_bf16_ok
= expect_data_types(bf16, bf16, bf16, bf16, undef)
&& attr_.post_ops_.len() == 0;
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
&& attr()->has_default_values(
smask_t::post_ops | smask_t::fpmath_mode, f32);
bool ok = is_fwd()
&& set_default_alg_kind(alg_kind::convolution_direct)
&& utils::one_of(true, is_fp16_ok, is_fp32_ok)
&& utils::one_of(true, is_fp16_ok, is_bf16_ok, is_fp32_ok)
&& !has_zero_dim_memory();
if (!ok) return status::unimplemented;

Expand Down
3 changes: 2 additions & 1 deletion src/cpu/cpu_convolution_list.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2020-2023 Arm Ltd. and affiliates
* Copyright 2020-2024 Arm Ltd. and affiliates
* Copyright 2020-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -179,6 +179,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t<bf16>)
CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
CPU_INSTANCE(ref_convolution_fwd_t)
CPU_INSTANCE(ref_fused_convolution_fwd_t)
nullptr,
Expand Down

0 comments on commit 3a05ca5

Please sign in to comment.