Skip to content

Commit

Permalink
src: cpu: aarch64: Fix deconv shape
Browse files Browse the repository at this point in the history
Fix an issue in deconv where ACL reduces the dimensions in TensorShape
if the last dimension is of size 1.
Also add unit test to catch this issue.
  • Loading branch information
davsva01 authored and vpirogov committed May 9, 2024
1 parent 307b35b commit b7694a0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/cpu/aarch64/acl_deconvolution.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022-2023 Arm Ltd. and affiliates
* Copyright 2022-2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -201,10 +201,14 @@ struct acl_deconvolution_fwd_t : public primitive_t {
: arm_compute::TensorShape(iw, ih, ic, mb),
1, acl_src_data_t, acl_layout);

acl_pd_conf.wei_info = arm_compute::TensorInfo(is_nspc
? arm_compute::TensorShape(ic, kw, kh, oc)
: arm_compute::TensorShape(kw, kh, ic, oc),
1, acl_wei_data_t, acl_layout);
auto wei_info_tensor_shape = is_nspc
? arm_compute::TensorShape(ic, kw, kh, oc)
: arm_compute::TensorShape(kw, kh, ic, oc);
// ACL removes last dimension if dim is 1.
// Below fix ensures the tensor shape is correct when queried.
wei_info_tensor_shape.set_num_dimensions(4);
acl_pd_conf.wei_info = arm_compute::TensorInfo(
wei_info_tensor_shape, 1, acl_wei_data_t, acl_layout);

acl_pd_conf.dst_info = arm_compute::TensorInfo(is_nspc
? arm_compute::TensorShape(oc, ow, oh, mb)
Expand Down
1 change: 1 addition & 0 deletions tests/benchdnn/inputs/deconv/shapes_1d
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ mb256oc256ow13ic384iw13kw3pw1n"alexnet:deconv3"
g1mb96ic64iw112oc3ow224kw7sw2pw3n"googlenet_v1:conv1/7x7_s2"
mb1_g1oc3ic64_ow1030iw512kw7sw2dw0pw0_n"masknet_p1:deconv1"
g1mb50ic256iw28oc512ow56kw1sw2pw0n"resnet_50:res3a_branch1"
mb9_ic1oc1_ih1oh1kh1sh1dh0ph0_iw55ow55kw3sw1dw0pw1n"pytorch_unittest"

0 comments on commit b7694a0

Please sign in to comment.