From 58457600b0180d15132119f963c1a03b46b86ab7 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Mon, 28 Mar 2022 15:58:31 +0800 Subject: [PATCH 01/14] [BYOC][ACL] Fix list is not supported as an input node --- .../tvm/relay/op/contrib/arm_compute_lib.py | 1 + .../contrib/arm_compute_lib/acl_runtime.cc | 67 +++++++++- .../contrib/arm_compute_lib/acl_utils.cc | 11 +- .../contrib/arm_compute_lib/acl_utils.h | 8 +- src/runtime/contrib/json/json_runtime.h | 1 + .../test_arm_compute_lib/test_concate.py | 126 ++++++++++++++++++ 6 files changed, 200 insertions(+), 14 deletions(-) create mode 100644 tests/python/contrib/test_arm_compute_lib/test_concate.py diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 9f3c1cdec0f7..9ddbfb24573d 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -286,6 +286,7 @@ def _func_wrapper(expr): _register_external_op_helper("reshape") +_register_external_op_helper("concatenate") @tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib") diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index a336cf494f4b..190b833339df 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -31,6 +31,7 @@ #ifdef TVM_GRAPH_EXECUTOR_ARM_COMPUTE_LIB #include #include +#include #include #include #include @@ -93,10 +94,19 @@ class ACLRuntime : public JSONRuntimeBase { void Run() override { for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; - uint32_t eid = EntryID(nid, 0); if (nodes_[nid].GetOpType() == "input") { - void* data = data_entry_[eid]->data; - CheckACLError(layer_.inputs[i].allocator()->import_memory(data)); + for (int index = 0; index < nodes_[nid].GetNumOutput(); index++) { + uint32_t eid = EntryID(nid, index); + void* data = data_entry_[eid]->data; + auto key = std::pair(nid, index); + if (layer_.json_inputid_to_layer_inputid.count(key) > 0) { + CheckACLError( + layer_.inputs[layer_.json_inputid_to_layer_inputid[key]].allocator()->import_memory( + data)); + } else { + CheckACLError(layer_.inputs[i].allocator()->import_memory(data)); + } + } } } @@ -149,6 +159,8 @@ class ACLRuntime : public JSONRuntimeBase { CreateMaximumLayer(&layer_, node); } else if ("add" == op_name || "qnn.add" == op_name) { CreateAddLayer(&layer_, node); + } else if ("concatenate" == op_name) { + CreateConcatenateLayer(&layer_, node); } else { LOG(FATAL) << "Unsupported op: " << op_name; } @@ -166,6 +178,7 @@ class ACLRuntime : public JSONRuntimeBase { std::shared_ptr function; std::vector inputs; std::vector outputs; + std::map, uint32_t> json_inputid_to_layer_inputid; }; /*! @@ -175,17 +188,25 @@ class ACLRuntime : public JSONRuntimeBase { * \param tensor The tensor to represent. * \param scale (optional) The scale of the tensor as an input. * \param offset (optional) The offset of the tensor as an input. + * \param apply_dim_correction (Optional) Flag to state whether apply dimension correction after + * setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but + * _num_dimensions should be 3 rather than 1. + * \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of + * dimensions of the shape. * \return ACL Tensor. */ arm_compute::Tensor MakeACLTensorFromJSONEntry(const JSONGraphNodeEntry& tensor, JSONGraphNodeEntry* scale = nullptr, - JSONGraphNodeEntry* offset = nullptr) { + JSONGraphNodeEntry* offset = nullptr, + bool apply_dim_correction = true, + bool increase_dim_unit = true) { JSONGraphNode node = nodes_[tensor.id_]; void* node_data = nullptr; if (node.GetOpType() == "const") { node_data = data_entry_[EntryID(tensor)]->data; } - return MakeACLTensorFromJSONNode(node, scale, offset, node_data); + return MakeACLTensorFromJSONNode(node, scale, offset, node_data, apply_dim_correction, + increase_dim_unit); } /*! @@ -196,19 +217,27 @@ class ACLRuntime : public JSONRuntimeBase { * \param scale (optional) The scale of the tensor as an input. * \param offset (optional) The offset of the tensor as an input. * \param data (optional) Constant data of input node. + * \param apply_dim_correction (Optional) Flag to state whether apply dimension correction after + * setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but + * _num_dimensions should be 3 rather than 1. + * \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of + * dimensions of the shape. * \return ACL Tensor. */ arm_compute::Tensor MakeACLTensorFromJSONNode(const JSONGraphNode& node, JSONGraphNodeEntry* scale = nullptr, JSONGraphNodeEntry* offset = nullptr, - void* data = nullptr) { + void* data = nullptr, + bool apply_dim_correction = true, + bool increase_dim_unit = true) { const DLTensor* scale_data = nullptr; const DLTensor* offset_data = nullptr; if (scale && offset) { scale_data = data_entry_[EntryID(*scale)]; offset_data = data_entry_[EntryID(*offset)]; } - return MakeACLTensor(node, data, scale_data, offset_data); + return MakeACLTensor(node, data, scale_data, offset_data, apply_dim_correction, + increase_dim_unit); } /*! @@ -510,6 +539,30 @@ class ACLRuntime : public JSONRuntimeBase { layer->function = f; } + /*! + * \brief Create a Concatenate layer. + * + * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.c + * \param node The JSON representation of the operator. + */ + void CreateConcatenateLayer(CachedLayer* layer, const JSONGraphNode& node) { + std::vector axis = node.GetAttr>("axis"); + std::vector inputs; + for (auto input : node.GetInputs()) { + layer->inputs.push_back(MakeACLTensorFromJSONEntry(input, nullptr, nullptr, false)); + layer->json_inputid_to_layer_inputid[std::pair(input.id_, input.index_)] = + layer->inputs.size() - 1; + } + for (size_t i = 0; i < layer->inputs.size(); i++) { + inputs.push_back(&layer->inputs[i]); + } + layer->outputs.push_back(MakeACLTensorFromJSONNode(node)); + int dimNum = layer->inputs[0].info()->num_dimensions(); + auto function = std::make_shared(); + function->configure(inputs, &layer->outputs[0], dimNum - std::stoi(axis[0]) - 1); + layer->function = function; + } + /*! \brief Allow ACL functions to request auxiliary memory from TVM. */ ACLAllocator allocator_; /*! diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.cc b/src/runtime/contrib/arm_compute_lib/acl_utils.cc index 3b2620987ab0..be288fbd1da7 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.cc @@ -40,11 +40,13 @@ void CheckACLError(const arm_compute::Status& status) { } arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data, - const DLTensor* scale, const DLTensor* offset) { + const DLTensor* scale, const DLTensor* offset, + bool apply_dim_correction, bool increase_dim_unit) { arm_compute::Tensor tensor; std::vector shape = tensor_rep.GetOpShape()[0]; DLDataType dtype = tensor_rep.GetOpDataType()[0]; - arm_compute::TensorInfo info = MakeACLTensorInfo(shape, dtype, scale, offset); + arm_compute::TensorInfo info = + MakeACLTensorInfo(shape, dtype, scale, offset, apply_dim_correction, increase_dim_unit); info.set_is_resizable(false); tensor.allocator()->init(info); if (data != nullptr) { @@ -55,10 +57,11 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data, arm_compute::TensorInfo MakeACLTensorInfo(const std::vector& shape, const DLDataType& dtype, const DLTensor* scale, - const DLTensor* offset) { + const DLTensor* offset, bool apply_dim_correction, + bool increase_dim_unit) { arm_compute::TensorShape acl_shape; for (unsigned int i = shape.size(); i > 0; --i) { - acl_shape.set(shape.size() - i, shape[i - 1]); + acl_shape.set(shape.size() - i, shape[i - 1], apply_dim_correction, increase_dim_unit); } arm_compute::DataType acl_dtype = MakeACLDataType(dtype); arm_compute::TensorInfo info(acl_shape, 1, acl_dtype, arm_compute::DataLayout::NHWC); diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.h b/src/runtime/contrib/arm_compute_lib/acl_utils.h index dbb006fbb347..8225c91da2f4 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.h +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.h @@ -63,8 +63,8 @@ void CheckACLError(const arm_compute::Status& status); * \return arm_compute::Tensor. */ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data = nullptr, - const DLTensor* scale = nullptr, - const DLTensor* offset = nullptr); + const DLTensor* scale = nullptr, const DLTensor* offset = nullptr, + bool apply_dim_correction = true, bool increase_dim_unit = true); /*! * \brief Make an acl tensor info object from JSON tensor @@ -78,7 +78,9 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data = */ arm_compute::TensorInfo MakeACLTensorInfo(const std::vector& shape, const DLDataType& dtype, const DLTensor* scale = nullptr, - const DLTensor* offset = nullptr); + const DLTensor* offset = nullptr, + bool apply_dim_correction = true, + bool increase_dim_unit = true); /*! * \brief Create a memory manager for use with a layer that diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 1735d8569215..b122a435de40 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -186,6 +186,7 @@ class JSONRuntimeBase : public ModuleNode { for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { input_var_eid_.push_back(EntryID(nid, j)); } + nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size()); } else { ICHECK_EQ(nodes_[nid].op_type_, "const"); auto pos = std::find(std::begin(const_names_), std::end(const_names_), name); diff --git a/tests/python/contrib/test_arm_compute_lib/test_concate.py b/tests/python/contrib/test_arm_compute_lib/test_concate.py new file mode 100644 index 000000000000..1cfc8c42e0f4 --- /dev/null +++ b/tests/python/contrib/test_arm_compute_lib/test_concate.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Arm Compute Library integration space_to_batch_nd tests.""" + +import numpy as np + +import tvm +from tvm import relay +from tvm import testing + +from test_arm_compute_libinfrastructure import ( + skip_runtime_test, + skip_codegen_test, + build_and_run, + verify, + verify_codegen, +) +from test_arm_compute_libinfrastructure import Device + + +def _get_model(input_shape_a, input_shape_b, input_shape_c, axis, dtype, var_names): + """Return a model and any parameters it may have.""" + a = relay.var(next(var_names), shape=input_shape_a, dtype=dtype) + b = relay.var(next(var_names), shape=input_shape_b, dtype=dtype) + c = relay.var(next(var_names), shape=input_shape_c, dtype=dtype) + out = relay.concatenate([a, b, c], axis) + return out + + +def _get_expected_codegen(input_shape_a, input_shape_b, input_shape_c, axis, dtype): + node = { + "op": "kernel", + "name": "concatenate", + "inputs": [ + [0, 0, 0], + [0, 1, 0], + [0, 2, 0], + ], + "attrs": { + "num_outputs": "1", + "num_inputs": "3", + "dtype": [[dtype]], + "axis": [[str(axis)]], + "shape": [[[3, 234, 234, 256]]], + }, + } + + input = { + "op": "input", + "name": "", + "attrs": { + "shape": [[input_shape_a, input_shape_b, input_shape_c]], + "dtype": [[dtype, dtype, dtype]], + }, + } + + return [input, node] + + +def test_concatenate(): + Device.load("test_config.json") + + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + for input_shape_a, input_shape_b, input_shape_c, axis in [ + ([1, 234, 234, 256], [1, 234, 234, 256], [1, 234, 234, 256], 0), + ]: + dtype = "int32" + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.randn(*input_shape_a).astype(dtype)), + "b": tvm.nd.array(np.random.randn(*input_shape_b).astype(dtype)), + "c": tvm.nd.array(np.random.randn(*input_shape_c).astype(dtype)), + } + func = _get_model( + inputs["a"].shape, inputs["b"].shape, inputs["c"].shape, axis, dtype, iter(inputs) + ) + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl)[0]) + + config = { + "input_shape_a": input_shape_a, + "input_shape_b": input_shape_b, + "input_shape_c": input_shape_c, + "axis": 0, + "dtype": dtype, + } + verify(outputs, atol=1e-7, rtol=1e-7, config=config) + + +def test_codegen_concatenate(): + if skip_codegen_test(): + return + shape_a = [1, 234, 234, 256] + shape_b = [1, 234, 234, 256] + shape_c = [1, 234, 234, 256] + axis = 0 + inputs = {"a", "b", "c"} + for dtype in ["float32"]: + args = (shape_a, shape_b, shape_c, axis, dtype) + func = _get_model(*args, iter(inputs)) + exp_codegen = _get_expected_codegen(*args) + verify_codegen(func, exp_codegen, 1) + + +if __name__ == "__main__": + test_concatenate() + test_codegen_concatenate() From 937e91afc77c24bcf3a3b82ae7a127a6f29b6802 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Mon, 28 Mar 2022 16:32:45 +0800 Subject: [PATCH 02/14] fix clang lint error --- src/runtime/contrib/json/json_runtime.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index b122a435de40..0c6d0f6d7136 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -186,7 +186,7 @@ class JSONRuntimeBase : public ModuleNode { for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { input_var_eid_.push_back(EntryID(nid, j)); } - nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size()); + nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size()); } else { ICHECK_EQ(nodes_[nid].op_type_, "const"); auto pos = std::find(std::begin(const_names_), std::end(const_names_), name); From b47f26a940e2b58548e0f2be4a3d6bfda737413f Mon Sep 17 00:00:00 2001 From: XuZhi Date: Mon, 28 Mar 2022 18:38:25 +0800 Subject: [PATCH 03/14] fix compile warnning --- src/runtime/contrib/arm_compute_lib/acl_runtime.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 190b833339df..a06640edcc23 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -95,7 +95,7 @@ class ACLRuntime : public JSONRuntimeBase { for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; if (nodes_[nid].GetOpType() == "input") { - for (int index = 0; index < nodes_[nid].GetNumOutput(); index++) { + for (uint32_t index = 0; index < nodes_[nid].GetNumOutput(); index++) { uint32_t eid = EntryID(nid, index); void* data = data_entry_[eid]->data; auto key = std::pair(nid, index); From 4089f17812437a36f4aeac39595df07abb574693 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Mon, 28 Mar 2022 19:58:05 +0800 Subject: [PATCH 04/14] fix python module import error --- tests/python/contrib/test_arm_compute_lib/test_concate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_arm_compute_lib/test_concate.py b/tests/python/contrib/test_arm_compute_lib/test_concate.py index 1cfc8c42e0f4..e7e5618c1e20 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_concate.py +++ b/tests/python/contrib/test_arm_compute_lib/test_concate.py @@ -22,14 +22,14 @@ from tvm import relay from tvm import testing -from test_arm_compute_libinfrastructure import ( +from test_arm_compute_lib.infrastructure import ( skip_runtime_test, skip_codegen_test, build_and_run, verify, verify_codegen, ) -from test_arm_compute_libinfrastructure import Device +from test_arm_compute_lib.infrastructure import Device def _get_model(input_shape_a, input_shape_b, input_shape_c, axis, dtype, var_names): From 1d3aebcfe7ce1319cbd311e1d26d09faad299586 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Tue, 29 Mar 2022 18:32:14 +0800 Subject: [PATCH 05/14] rename concatenate test file --- .../test_arm_compute_lib/{test_concate.py => test_concatenate.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/contrib/test_arm_compute_lib/{test_concate.py => test_concatenate.py} (100%) diff --git a/tests/python/contrib/test_arm_compute_lib/test_concate.py b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py similarity index 100% rename from tests/python/contrib/test_arm_compute_lib/test_concate.py rename to tests/python/contrib/test_arm_compute_lib/test_concatenate.py From d89b7d493a31af4e2e6231750a3e3ad8ed4addf9 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Thu, 31 Mar 2022 18:17:15 +0800 Subject: [PATCH 06/14] fix always MakeACLTensor with same eid 0 --- .../tvm/relay/op/contrib/arm_compute_lib.py | 14 +++++++- .../contrib/arm_compute_lib/acl_runtime.cc | 35 +++++++++++-------- .../contrib/arm_compute_lib/acl_utils.cc | 7 ++-- .../contrib/arm_compute_lib/acl_utils.h | 3 +- .../test_arm_compute_lib/test_concatenate.py | 22 +++++++----- 5 files changed, 53 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 9ddbfb24573d..7c7a10e0c849 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -286,7 +286,6 @@ def _func_wrapper(expr): _register_external_op_helper("reshape") -_register_external_op_helper("concatenate") @tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib") @@ -491,6 +490,19 @@ def qnn_add(expr): return True +@tvm.ir.register_op_attr("concatenate", "target.arm_compute_lib") +def concatenate(expr): + """Check if the external ACL codegen for concatenate should be used.""" + attrs, type_args = expr.attrs, expr.type_args + for idx in range(len(type_args[0].fields)): + if type_args[0].fields[idx].dtype not in ["float32", "uint8"]: + return False + # ACL concatenate only supports maximum 4 dimensions input tensor + if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]: + return False + return True + + class OpAttrContext(object): """Temporarily changes the attr of an op.""" diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index a06640edcc23..5687e687cfb6 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -92,19 +92,19 @@ class ACLRuntime : public JSONRuntimeBase { * \return Status of inference. */ void Run() override { - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto nid = input_nodes_[i]; + for (size_t nid_idx = 0; nid_idx < input_nodes_.size(); ++nid_idx) { + auto nid = input_nodes_[nid_idx]; if (nodes_[nid].GetOpType() == "input") { - for (uint32_t index = 0; index < nodes_[nid].GetNumOutput(); index++) { - uint32_t eid = EntryID(nid, index); + for (uint32_t eid_idx = 0; eid_idx < nodes_[nid].GetNumOutput(); eid_idx++) { + uint32_t eid = EntryID(nid, eid_idx); void* data = data_entry_[eid]->data; - auto key = std::pair(nid, index); + auto key = std::pair(nid, eid_idx); if (layer_.json_inputid_to_layer_inputid.count(key) > 0) { CheckACLError( layer_.inputs[layer_.json_inputid_to_layer_inputid[key]].allocator()->import_memory( data)); } else { - CheckACLError(layer_.inputs[i].allocator()->import_memory(data)); + CheckACLError(layer_.inputs[nid_idx].allocator()->import_memory(data)); } } } @@ -178,6 +178,8 @@ class ACLRuntime : public JSONRuntimeBase { std::shared_ptr function; std::vector inputs; std::vector outputs; + // maps the input index of JSON node to the index of the ACL layer's inputs + // this is optional (i.e.only when an operator uses the eid index) std::map, uint32_t> json_inputid_to_layer_inputid; }; @@ -206,7 +208,7 @@ class ACLRuntime : public JSONRuntimeBase { node_data = data_entry_[EntryID(tensor)]->data; } return MakeACLTensorFromJSONNode(node, scale, offset, node_data, apply_dim_correction, - increase_dim_unit); + increase_dim_unit, tensor.index_); } /*! @@ -222,14 +224,13 @@ class ACLRuntime : public JSONRuntimeBase { * _num_dimensions should be 3 rather than 1. * \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of * dimensions of the shape. + * \param entry_index The entry index. * \return ACL Tensor. */ - arm_compute::Tensor MakeACLTensorFromJSONNode(const JSONGraphNode& node, - JSONGraphNodeEntry* scale = nullptr, - JSONGraphNodeEntry* offset = nullptr, - void* data = nullptr, - bool apply_dim_correction = true, - bool increase_dim_unit = true) { + arm_compute::Tensor MakeACLTensorFromJSONNode( + const JSONGraphNode& node, JSONGraphNodeEntry* scale = nullptr, + JSONGraphNodeEntry* offset = nullptr, void* data = nullptr, bool apply_dim_correction = true, + bool increase_dim_unit = true, uint32_t entry_index = 0) { const DLTensor* scale_data = nullptr; const DLTensor* offset_data = nullptr; if (scale && offset) { @@ -237,7 +238,7 @@ class ACLRuntime : public JSONRuntimeBase { offset_data = data_entry_[EntryID(*offset)]; } return MakeACLTensor(node, data, scale_data, offset_data, apply_dim_correction, - increase_dim_unit); + increase_dim_unit, entry_index); } /*! @@ -559,7 +560,11 @@ class ACLRuntime : public JSONRuntimeBase { layer->outputs.push_back(MakeACLTensorFromJSONNode(node)); int dimNum = layer->inputs[0].info()->num_dimensions(); auto function = std::make_shared(); - function->configure(inputs, &layer->outputs[0], dimNum - std::stoi(axis[0]) - 1); + // the shape of input tensor will be reversed after passing to ACL + // for example a tensor with shape [1, 2, 3, 4] will be changed to + // [4, 3, 2, 1] at ACL side. So the axis here should be preprocessed. + auto a = std::stoi(axis[0]); + function->configure(inputs, &layer->outputs[0], a < 0 ? -a - 1 : dimNum - a - 1); layer->function = function; } diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.cc b/src/runtime/contrib/arm_compute_lib/acl_utils.cc index be288fbd1da7..238b7355de26 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.cc @@ -41,10 +41,11 @@ void CheckACLError(const arm_compute::Status& status) { arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data, const DLTensor* scale, const DLTensor* offset, - bool apply_dim_correction, bool increase_dim_unit) { + bool apply_dim_correction, bool increase_dim_unit, + uint32_t entry_index) { arm_compute::Tensor tensor; - std::vector shape = tensor_rep.GetOpShape()[0]; - DLDataType dtype = tensor_rep.GetOpDataType()[0]; + std::vector shape = tensor_rep.GetOpShape()[entry_index]; + DLDataType dtype = tensor_rep.GetOpDataType()[entry_index]; arm_compute::TensorInfo info = MakeACLTensorInfo(shape, dtype, scale, offset, apply_dim_correction, increase_dim_unit); info.set_is_resizable(false); diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.h b/src/runtime/contrib/arm_compute_lib/acl_utils.h index 8225c91da2f4..a553839240e4 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.h +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.h @@ -64,7 +64,8 @@ void CheckACLError(const arm_compute::Status& status); */ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data = nullptr, const DLTensor* scale = nullptr, const DLTensor* offset = nullptr, - bool apply_dim_correction = true, bool increase_dim_unit = true); + bool apply_dim_correction = true, bool increase_dim_unit = true, + uint32_t entry_index = 0); /*! * \brief Make an acl tensor info object from JSON tensor diff --git a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py index e7e5618c1e20..86adf6e0c33e 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py +++ b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Arm Compute Library integration space_to_batch_nd tests.""" +"""Arm Compute Library integration concatenate tests.""" import numpy as np @@ -55,7 +55,7 @@ def _get_expected_codegen(input_shape_a, input_shape_b, input_shape_c, axis, dty "num_inputs": "3", "dtype": [[dtype]], "axis": [[str(axis)]], - "shape": [[[3, 234, 234, 256]]], + "shape": [[[6, 234, 234, 256]]], }, } @@ -80,10 +80,16 @@ def test_concatenate(): device = Device() np.random.seed(0) - for input_shape_a, input_shape_b, input_shape_c, axis in [ - ([1, 234, 234, 256], [1, 234, 234, 256], [1, 234, 234, 256], 0), + for input_shape_a, input_shape_b, input_shape_c, axis, dtype in [ + ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], 0, "float32"), + ([1, 1, 234, 256], [1, 2, 234, 256], [1, 3, 234, 256], 1, "float32"), + ([1, 234, 234, 1], [1, 234, 234, 2], [1, 234, 234, 3], -1, "float32"), + ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], -4, "float32"), + ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], 0, "uint8"), + ([1, 1, 234, 256], [1, 2, 234, 256], [1, 3, 234, 256], 1, "uint8"), + ([1, 234, 234, 1], [1, 234, 234, 2], [1, 234, 234, 3], -1, "uint8"), + ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], -4, "uint8"), ]: - dtype = "int32" outputs = [] inputs = { "a": tvm.nd.array(np.random.randn(*input_shape_a).astype(dtype)), @@ -100,7 +106,7 @@ def test_concatenate(): "input_shape_a": input_shape_a, "input_shape_b": input_shape_b, "input_shape_c": input_shape_c, - "axis": 0, + "axis": axis, "dtype": dtype, } verify(outputs, atol=1e-7, rtol=1e-7, config=config) @@ -110,8 +116,8 @@ def test_codegen_concatenate(): if skip_codegen_test(): return shape_a = [1, 234, 234, 256] - shape_b = [1, 234, 234, 256] - shape_c = [1, 234, 234, 256] + shape_b = [2, 234, 234, 256] + shape_c = [3, 234, 234, 256] axis = 0 inputs = {"a", "b", "c"} for dtype in ["float32"]: From 8c87d4571eac72e003ebc63a7183c7739b098dea Mon Sep 17 00:00:00 2001 From: XuZhi Date: Mon, 11 Apr 2022 19:20:50 +0800 Subject: [PATCH 07/14] do not offload concat default --- python/tvm/relay/op/contrib/arm_compute_lib.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 7c7a10e0c849..ac1fe28a4d99 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -28,6 +28,10 @@ from .register import register_pattern_table +# global variable control wether offload concatenate +offload_concat_ = False + + def is_arm_compute_runtime_enabled(): """Check if the ACL graph executor is present. @@ -42,7 +46,7 @@ def is_arm_compute_runtime_enabled(): return False -def partition_for_arm_compute_lib(mod, params=None, **opts): +def partition_for_arm_compute_lib(mod, params=None, offload_concat=False, **opts): """Partition the graph greedily offloading supported operators to Arm Compute Library. @@ -52,11 +56,15 @@ def partition_for_arm_compute_lib(mod, params=None, **opts): The module to run passes on. params : Optional[Dict[str, NDArray]] Constant input parameters. + offload_concat : Optional[bool] + Whether offload concatenate Returns ------- ret : annotated and partitioned module. """ + global offload_concat_ + offload_concat_ = offload_concat if params: mod["main"] = bind_params_by_name(mod["main"], params) @@ -493,6 +501,8 @@ def qnn_add(expr): @tvm.ir.register_op_attr("concatenate", "target.arm_compute_lib") def concatenate(expr): """Check if the external ACL codegen for concatenate should be used.""" + if not offload_concat_: + return False attrs, type_args = expr.attrs, expr.type_args for idx in range(len(type_args[0].fields)): if type_args[0].fields[idx].dtype not in ["float32", "uint8"]: From 6c03b19263a463b2640c244c50db42f18159eae3 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Mon, 11 Apr 2022 21:18:58 +0800 Subject: [PATCH 08/14] fix concattnate test failure --- tests/python/contrib/test_arm_compute_lib/infrastructure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index e582874d1de2..0e67819fa331 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -169,7 +169,7 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti mod = tvm.IRModule.from_expr(mod) with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): if enable_acl: - mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params) + mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params, offload_concat=True) tvm_op_count = get_cpu_op_count(mod) assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format( tvm_op_count, tvm_ops From 9d3631875c3861f0f2bd8a1cf9d938c88fff166c Mon Sep 17 00:00:00 2001 From: XuZhi Date: Mon, 11 Apr 2022 23:44:03 +0800 Subject: [PATCH 09/14] fix test failure --- .../contrib/test_arm_compute_lib/infrastructure.py | 10 ++++++---- .../contrib/test_arm_compute_lib/test_concatenate.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index 0e67819fa331..bb75d05076b7 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -163,13 +163,13 @@ def skip_codegen_test(): return True -def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_partitions=1): +def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_partitions=1, offload_concat=False): """Build module with option to build for ACL.""" if isinstance(mod, tvm.relay.expr.Call): mod = tvm.IRModule.from_expr(mod) with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): if enable_acl: - mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params, offload_concat=True) + mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params, offload_concat=offload_concat) tvm_op_count = get_cpu_op_count(mod) assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format( tvm_op_count, tvm_ops @@ -199,13 +199,14 @@ def build_and_run( tvm_ops=0, acl_partitions=1, config=None, + offload_concat=False ): """Build and run the relay module.""" if config is None: config = {} try: - lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions) + lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions, offload_concat) except Exception as e: err_msg = "The module could not be built.\n" if config: @@ -276,9 +277,10 @@ def verify_codegen( num_acl_modules=1, tvm_ops=0, target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon", + offload_concat=False, ): """Check acl codegen against a known good output.""" - module = build_module(module, target, tvm_ops=tvm_ops, acl_partitions=num_acl_modules) + module = build_module(module, target, tvm_ops=tvm_ops, acl_partitions=num_acl_modules, offload_concat=offload_concat) acl_modules = extract_acl_modules(module) assert len(acl_modules) == num_acl_modules, ( diff --git a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py index 86adf6e0c33e..42bbd745c18d 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py +++ b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py @@ -100,7 +100,7 @@ def test_concatenate(): inputs["a"].shape, inputs["b"].shape, inputs["c"].shape, axis, dtype, iter(inputs) ) for acl in [False, True]: - outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl)[0]) + outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl, offload_concat=True)[0]) config = { "input_shape_a": input_shape_a, @@ -124,7 +124,7 @@ def test_codegen_concatenate(): args = (shape_a, shape_b, shape_c, axis, dtype) func = _get_model(*args, iter(inputs)) exp_codegen = _get_expected_codegen(*args) - verify_codegen(func, exp_codegen, 1) + verify_codegen(func, exp_codegen, 1, offload_concat=True) if __name__ == "__main__": From ca28694d73a1d78a0ab9a36bab0283b5e389718c Mon Sep 17 00:00:00 2001 From: XuZhi Date: Tue, 12 Apr 2022 00:18:17 +0800 Subject: [PATCH 10/14] fix lint error --- .../test_arm_compute_lib/infrastructure.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index bb75d05076b7..92d2bfdc43d5 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -163,13 +163,17 @@ def skip_codegen_test(): return True -def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_partitions=1, offload_concat=False): +def build_module( + mod, target, params=None, enable_acl=True, tvm_ops=0, acl_partitions=1, offload_concat=False +): """Build module with option to build for ACL.""" if isinstance(mod, tvm.relay.expr.Call): mod = tvm.IRModule.from_expr(mod) with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): if enable_acl: - mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params, offload_concat=offload_concat) + mod = arm_compute_lib.partition_for_arm_compute_lib( + mod, params, offload_concat=offload_concat + ) tvm_op_count = get_cpu_op_count(mod) assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format( tvm_op_count, tvm_ops @@ -199,14 +203,16 @@ def build_and_run( tvm_ops=0, acl_partitions=1, config=None, - offload_concat=False + offload_concat=False, ): """Build and run the relay module.""" if config is None: config = {} try: - lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions, offload_concat) + lib = build_module( + mod, device.target, params, enable_acl, tvm_ops, acl_partitions, offload_concat + ) except Exception as e: err_msg = "The module could not be built.\n" if config: @@ -280,7 +286,13 @@ def verify_codegen( offload_concat=False, ): """Check acl codegen against a known good output.""" - module = build_module(module, target, tvm_ops=tvm_ops, acl_partitions=num_acl_modules, offload_concat=offload_concat) + module = build_module( + module, + target, + tvm_ops=tvm_ops, + acl_partitions=num_acl_modules, + offload_concat=offload_concat, + ) acl_modules = extract_acl_modules(module) assert len(acl_modules) == num_acl_modules, ( From 9bb6421c0aee6fecc57e0ce85697ae3928d741c5 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Tue, 12 Apr 2022 09:52:17 +0800 Subject: [PATCH 11/14] fix lint --- tests/python/contrib/test_arm_compute_lib/test_concatenate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py index 42bbd745c18d..ecf239ff3b92 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py +++ b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py @@ -100,7 +100,9 @@ def test_concatenate(): inputs["a"].shape, inputs["b"].shape, inputs["c"].shape, axis, dtype, iter(inputs) ) for acl in [False, True]: - outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl, offload_concat=True)[0]) + outputs.append( + build_and_run(func, inputs, 1, None, device, enable_acl=acl, offload_concat=True)[0] + ) config = { "input_shape_a": input_shape_a, From 56eb71466d57902b57040db4995ab9fb8350add8 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Wed, 13 Apr 2022 17:18:34 +0800 Subject: [PATCH 12/14] remove global var offload_concat --- .../tvm/relay/op/contrib/arm_compute_lib.py | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index ac1fe28a4d99..c3e146b1a6c4 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -28,10 +28,6 @@ from .register import register_pattern_table -# global variable control wether offload concatenate -offload_concat_ = False - - def is_arm_compute_runtime_enabled(): """Check if the ACL graph executor is present. @@ -63,15 +59,13 @@ def partition_for_arm_compute_lib(mod, params=None, offload_concat=False, **opts ------- ret : annotated and partitioned module. """ - global offload_concat_ - offload_concat_ = offload_concat if params: mod["main"] = bind_params_by_name(mod["main"], params) seq = tvm.transform.Sequential( [ transform.InferType(), - transform.MergeComposite(arm_compute_lib_pattern_table()), + transform.MergeComposite(arm_compute_lib_pattern_table(offload_concat)), transform.AnnotateTarget("arm_compute_lib", False), transform.PartitionGraph(), ] @@ -136,7 +130,7 @@ def convert_conv(attrs, inputs, tinfos, desired_layouts): @register_pattern_table("arm_compute_lib") -def arm_compute_lib_pattern_table(): +def arm_compute_lib_pattern_table(offload_concat=False): """Get the ACL pattern table.""" def conv_pattern(): @@ -274,6 +268,20 @@ def check_l2_pool2d(extract): pool = extract.args[0] return avg_pool2d(pool) + if offload_concat and not tvm.ir.Op.get("concatenate").get_attr("target.arm_compute_lib"): + + @tvm.ir.register_op_attr("concatenate", "target.arm_compute_lib") + def concatenate(expr): + """Check if the external ACL codegen for concatenate should be used.""" + attrs, type_args = expr.attrs, expr.type_args + for idx in range(len(type_args[0].fields)): + if type_args[0].fields[idx].dtype not in ["float32", "uint8"]: + return False + # ACL concatenate only supports maximum 4 dimensions input tensor + if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]: + return False + return True + return [ ("arm_compute_lib.conv2d", conv_pattern(), check_conv), ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv), @@ -498,21 +506,6 @@ def qnn_add(expr): return True -@tvm.ir.register_op_attr("concatenate", "target.arm_compute_lib") -def concatenate(expr): - """Check if the external ACL codegen for concatenate should be used.""" - if not offload_concat_: - return False - attrs, type_args = expr.attrs, expr.type_args - for idx in range(len(type_args[0].fields)): - if type_args[0].fields[idx].dtype not in ["float32", "uint8"]: - return False - # ACL concatenate only supports maximum 4 dimensions input tensor - if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]: - return False - return True - - class OpAttrContext(object): """Temporarily changes the attr of an op.""" From b219357e9d1b79c8d1675b358addc478bae9f41f Mon Sep 17 00:00:00 2001 From: XuZhi Date: Wed, 13 Apr 2022 22:02:54 +0800 Subject: [PATCH 13/14] support concatenate with pattern table mechanism --- .../tvm/relay/op/contrib/arm_compute_lib.py | 47 ++++++++++++------- .../contrib/arm_compute_lib/codegen.cc | 26 ++++++++++ .../test_arm_compute_lib/infrastructure.py | 18 ++++--- .../test_arm_compute_lib/test_concatenate.py | 33 +++++++++---- 4 files changed, 92 insertions(+), 32 deletions(-) diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index c3e146b1a6c4..bcf1ee0b6732 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -23,7 +23,7 @@ from tvm.relay.build_module import bind_params_by_name from tvm.relay.expr import const -from ...dataflow_pattern import is_constant, is_expr, is_op, wildcard +from ...dataflow_pattern import is_constant, is_expr, is_op, is_tuple, wildcard from ..strategy.generic import is_depthwise_conv2d from .register import register_pattern_table @@ -42,7 +42,7 @@ def is_arm_compute_runtime_enabled(): return False -def partition_for_arm_compute_lib(mod, params=None, offload_concat=False, **opts): +def partition_for_arm_compute_lib(mod, params=None, disabled_ops=["concatenate"], **opts): """Partition the graph greedily offloading supported operators to Arm Compute Library. @@ -52,8 +52,8 @@ def partition_for_arm_compute_lib(mod, params=None, offload_concat=False, **opts The module to run passes on. params : Optional[Dict[str, NDArray]] Constant input parameters. - offload_concat : Optional[bool] - Whether offload concatenate + disabled_ops : Optional[list] + Ops do not want to offload to ACL. Returns ------- @@ -65,7 +65,7 @@ def partition_for_arm_compute_lib(mod, params=None, offload_concat=False, **opts seq = tvm.transform.Sequential( [ transform.InferType(), - transform.MergeComposite(arm_compute_lib_pattern_table(offload_concat)), + transform.MergeComposite(arm_compute_lib_pattern_table(disabled_ops)), transform.AnnotateTarget("arm_compute_lib", False), transform.PartitionGraph(), ] @@ -130,7 +130,7 @@ def convert_conv(attrs, inputs, tinfos, desired_layouts): @register_pattern_table("arm_compute_lib") -def arm_compute_lib_pattern_table(offload_concat=False): +def arm_compute_lib_pattern_table(disabled_ops=["concatenate"]): """Get the ACL pattern table.""" def conv_pattern(): @@ -222,6 +222,17 @@ def l2_pool2d_pattern(): pattern = is_op("sqrt")(pattern) return pattern + def concatenate_pattern(): + """Create an concatenate pattern from equivalent relay operators. + + Returns + ------- + pattern : dataflow_pattern.AltPattern + Denotes the concatenate pattern. + """ + pattern = is_op("concatenate")(is_tuple(None)) + return pattern + def check_conv(extract): """Check conv pattern is supported by ACL.""" call = extract @@ -268,19 +279,18 @@ def check_l2_pool2d(extract): pool = extract.args[0] return avg_pool2d(pool) - if offload_concat and not tvm.ir.Op.get("concatenate").get_attr("target.arm_compute_lib"): - - @tvm.ir.register_op_attr("concatenate", "target.arm_compute_lib") - def concatenate(expr): - """Check if the external ACL codegen for concatenate should be used.""" - attrs, type_args = expr.attrs, expr.type_args - for idx in range(len(type_args[0].fields)): - if type_args[0].fields[idx].dtype not in ["float32", "uint8"]: - return False - # ACL concatenate only supports maximum 4 dimensions input tensor - if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]: + def check_concatenate(expr): + """Check concatenate pattern is supported by ACL.""" + if "concatenate" in disabled_ops: + return False + attrs, type_args = expr.attrs, expr.type_args + for idx in range(len(type_args[0].fields)): + if type_args[0].fields[idx].dtype not in ["float32", "uint8"]: return False - return True + # ACL concatenate only supports maximum 4 dimensions input tensor + if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]: + return False + return True return [ ("arm_compute_lib.conv2d", conv_pattern(), check_conv), @@ -290,6 +300,7 @@ def concatenate(expr): ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv), ("arm_compute_lib.avg_pool2d", avg_pool2d_pattern(), check_avg_pool2d), ("arm_compute_lib.l2_pool2d", l2_pool2d_pattern(), check_l2_pool2d), + ("arm_compute_lib.concatenate", concatenate_pattern(), check_concatenate), ] diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 8098c8d51274..842ede3bf20b 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -99,6 +99,8 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { json_node = CreateCompositeAvgPool2DJSONNode(cn); } else if (name == "arm_compute_lib.l2_pool2d") { json_node = CreateCompositeL2Pool2DJSONNode(cn); + } else if (name == "arm_compute_lib.concatenate") { + return AddCommonSingleJSONNode(cn, "concatenate"); } else { LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name; } @@ -342,6 +344,30 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { SetCallNodeAttribute(json_node, avg_pool); return json_node; } + + /*! + * \brief Create a JSON representation of a single operator. + * \param cn The call to be represented. + * \param name The name of the operator. + * \return A list of graph entry nodes. + */ + std::vector AddCommonSingleJSONNode(const CallNode* cn, std::string name) { + std::vector inputs; + for (const auto& arg : cn->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + + const auto* fn = cn->op.as(); + ICHECK(fn); + const auto* callNode = fn->body.as(); + ICHECK(callNode); + SetCallNodeAttribute(node, callNode); + return AddNode(node, GetRef(cn)); + } }; /*! diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index 92d2bfdc43d5..314da972c049 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -164,7 +164,13 @@ def skip_codegen_test(): def build_module( - mod, target, params=None, enable_acl=True, tvm_ops=0, acl_partitions=1, offload_concat=False + mod, + target, + params=None, + enable_acl=True, + tvm_ops=0, + acl_partitions=1, + disabled_ops=["concatenate"], ): """Build module with option to build for ACL.""" if isinstance(mod, tvm.relay.expr.Call): @@ -172,7 +178,7 @@ def build_module( with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): if enable_acl: mod = arm_compute_lib.partition_for_arm_compute_lib( - mod, params, offload_concat=offload_concat + mod, params, disabled_ops=disabled_ops ) tvm_op_count = get_cpu_op_count(mod) assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format( @@ -203,7 +209,7 @@ def build_and_run( tvm_ops=0, acl_partitions=1, config=None, - offload_concat=False, + disabled_ops=["concatenate"], ): """Build and run the relay module.""" if config is None: @@ -211,7 +217,7 @@ def build_and_run( try: lib = build_module( - mod, device.target, params, enable_acl, tvm_ops, acl_partitions, offload_concat + mod, device.target, params, enable_acl, tvm_ops, acl_partitions, disabled_ops ) except Exception as e: err_msg = "The module could not be built.\n" @@ -283,7 +289,7 @@ def verify_codegen( num_acl_modules=1, tvm_ops=0, target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon", - offload_concat=False, + disabled_ops=["concatenate"], ): """Check acl codegen against a known good output.""" module = build_module( @@ -291,7 +297,7 @@ def verify_codegen( target, tvm_ops=tvm_ops, acl_partitions=num_acl_modules, - offload_concat=offload_concat, + disabled_ops=disabled_ops, ) acl_modules = extract_acl_modules(module) diff --git a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py index ecf239ff3b92..deba26a0db56 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py +++ b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py @@ -47,8 +47,8 @@ def _get_expected_codegen(input_shape_a, input_shape_b, input_shape_c, axis, dty "name": "concatenate", "inputs": [ [0, 0, 0], - [0, 1, 0], - [0, 2, 0], + [1, 0, 0], + [2, 0, 0], ], "attrs": { "num_outputs": "1", @@ -59,16 +59,33 @@ def _get_expected_codegen(input_shape_a, input_shape_b, input_shape_c, axis, dty }, } - input = { + input_a = { "op": "input", "name": "", "attrs": { - "shape": [[input_shape_a, input_shape_b, input_shape_c]], - "dtype": [[dtype, dtype, dtype]], + "shape": [[input_shape_a]], + "dtype": [[dtype]], + }, + } + + input_b = { + "op": "input", + "name": "", + "attrs": { + "shape": [[input_shape_b]], + "dtype": [[dtype]], }, } - return [input, node] + input_c = { + "op": "input", + "name": "", + "attrs": { + "shape": [[input_shape_c]], + "dtype": [[dtype]], + }, + } + return [input_a, input_b, input_c, node] def test_concatenate(): @@ -101,7 +118,7 @@ def test_concatenate(): ) for acl in [False, True]: outputs.append( - build_and_run(func, inputs, 1, None, device, enable_acl=acl, offload_concat=True)[0] + build_and_run(func, inputs, 1, None, device, enable_acl=acl, disabled_ops=[])[0] ) config = { @@ -126,7 +143,7 @@ def test_codegen_concatenate(): args = (shape_a, shape_b, shape_c, axis, dtype) func = _get_model(*args, iter(inputs)) exp_codegen = _get_expected_codegen(*args) - verify_codegen(func, exp_codegen, 1, offload_concat=True) + verify_codegen(func, exp_codegen, 1, disabled_ops=[]) if __name__ == "__main__": From 4603c1289a406d5b787dc1ee6807ccc199c5b1d0 Mon Sep 17 00:00:00 2001 From: XuZhi Date: Thu, 14 Apr 2022 10:08:51 +0800 Subject: [PATCH 14/14] disable pylint dangerous-default-value warning --- python/tvm/relay/op/contrib/arm_compute_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index bcf1ee0b6732..9abd320b2956 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument +# pylint: disable=invalid-name, unused-argument, dangerous-default-value """Arm Compute Library supported operators.""" import tvm from tvm import relay