From b7694a00a26cfe3f0d9d9b36d16edac91bfdd65b Mon Sep 17 00:00:00 2001 From: David Svantesson-Yeung Date: Wed, 8 May 2024 10:23:45 +0000 Subject: [PATCH] src: cpu: aarch64: Fix deconv shape Fix an issue in deconv where ACL reduces the dimensions in TensorShape if the last dimension is of size 1. Also add unit test to catch this issue. --- src/cpu/aarch64/acl_deconvolution.hpp | 14 +++++++++----- tests/benchdnn/inputs/deconv/shapes_1d | 1 + 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/cpu/aarch64/acl_deconvolution.hpp b/src/cpu/aarch64/acl_deconvolution.hpp index 2bd6bbfb802..4b646148b1d 100644 --- a/src/cpu/aarch64/acl_deconvolution.hpp +++ b/src/cpu/aarch64/acl_deconvolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2023 Arm Ltd. and affiliates +* Copyright 2022-2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -201,10 +201,14 @@ struct acl_deconvolution_fwd_t : public primitive_t { : arm_compute::TensorShape(iw, ih, ic, mb), 1, acl_src_data_t, acl_layout); - acl_pd_conf.wei_info = arm_compute::TensorInfo(is_nspc - ? arm_compute::TensorShape(ic, kw, kh, oc) - : arm_compute::TensorShape(kw, kh, ic, oc), - 1, acl_wei_data_t, acl_layout); + auto wei_info_tensor_shape = is_nspc + ? arm_compute::TensorShape(ic, kw, kh, oc) + : arm_compute::TensorShape(kw, kh, ic, oc); + // ACL removes last dimension if dim is 1. + // Below fix ensures the tensor shape is correct when queried. + wei_info_tensor_shape.set_num_dimensions(4); + acl_pd_conf.wei_info = arm_compute::TensorInfo( + wei_info_tensor_shape, 1, acl_wei_data_t, acl_layout); acl_pd_conf.dst_info = arm_compute::TensorInfo(is_nspc ? arm_compute::TensorShape(oc, ow, oh, mb) diff --git a/tests/benchdnn/inputs/deconv/shapes_1d b/tests/benchdnn/inputs/deconv/shapes_1d index ab517bb79b9..da7b830b1df 100644 --- a/tests/benchdnn/inputs/deconv/shapes_1d +++ b/tests/benchdnn/inputs/deconv/shapes_1d @@ -6,3 +6,4 @@ mb256oc256ow13ic384iw13kw3pw1n"alexnet:deconv3" g1mb96ic64iw112oc3ow224kw7sw2pw3n"googlenet_v1:conv1/7x7_s2" mb1_g1oc3ic64_ow1030iw512kw7sw2dw0pw0_n"masknet_p1:deconv1" g1mb50ic256iw28oc512ow56kw1sw2pw0n"resnet_50:res3a_branch1" +mb9_ic1oc1_ih1oh1kh1sh1dh0ph0_iw55ow55kw3sw1dw0pw1n"pytorch_unittest"