Skip to content
1 change: 1 addition & 0 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def _func_wrapper(expr):


_register_external_op_helper("reshape")
_register_external_op_helper("concatenate")


@tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib")
Expand Down
67 changes: 60 additions & 7 deletions src/runtime/contrib/arm_compute_lib/acl_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#ifdef TVM_GRAPH_EXECUTOR_ARM_COMPUTE_LIB
#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/NEON/functions/NEArithmeticAddition.h>
#include <arm_compute/runtime/NEON/functions/NEConcatenateLayer.h>
#include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h>
#include <arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h>
#include <arm_compute/runtime/NEON/functions/NEElementwiseOperations.h>
Expand Down Expand Up @@ -93,10 +94,19 @@ class ACLRuntime : public JSONRuntimeBase {
void Run() override {
for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto nid = input_nodes_[i];
uint32_t eid = EntryID(nid, 0);
if (nodes_[nid].GetOpType() == "input") {
void* data = data_entry_[eid]->data;
CheckACLError(layer_.inputs[i].allocator()->import_memory(data));
for (uint32_t index = 0; index < nodes_[nid].GetNumOutput(); index++) {
uint32_t eid = EntryID(nid, index);
void* data = data_entry_[eid]->data;
auto key = std::pair<uint32_t, uint32_t>(nid, index);
if (layer_.json_inputid_to_layer_inputid.count(key) > 0) {
CheckACLError(
layer_.inputs[layer_.json_inputid_to_layer_inputid[key]].allocator()->import_memory(
data));
} else {
CheckACLError(layer_.inputs[i].allocator()->import_memory(data));
}
}
}
}

Expand Down Expand Up @@ -149,6 +159,8 @@ class ACLRuntime : public JSONRuntimeBase {
CreateMaximumLayer(&layer_, node);
} else if ("add" == op_name || "qnn.add" == op_name) {
CreateAddLayer(&layer_, node);
} else if ("concatenate" == op_name) {
CreateConcatenateLayer(&layer_, node);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
Expand All @@ -166,6 +178,7 @@ class ACLRuntime : public JSONRuntimeBase {
std::shared_ptr<arm_compute::IFunction> function;
std::vector<arm_compute::Tensor> inputs;
std::vector<arm_compute::Tensor> outputs;
std::map<std::pair<uint32_t, uint32_t>, uint32_t> json_inputid_to_layer_inputid;
};

/*!
Expand All @@ -175,17 +188,25 @@ class ACLRuntime : public JSONRuntimeBase {
* \param tensor The tensor to represent.
* \param scale (optional) The scale of the tensor as an input.
* \param offset (optional) The offset of the tensor as an input.
* \param apply_dim_correction (Optional) Flag to state whether apply dimension correction after
* setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but
* _num_dimensions should be 3 rather than 1.
* \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of
* dimensions of the shape.
* \return ACL Tensor.
*/
arm_compute::Tensor MakeACLTensorFromJSONEntry(const JSONGraphNodeEntry& tensor,
JSONGraphNodeEntry* scale = nullptr,
JSONGraphNodeEntry* offset = nullptr) {
JSONGraphNodeEntry* offset = nullptr,
bool apply_dim_correction = true,
bool increase_dim_unit = true) {
JSONGraphNode node = nodes_[tensor.id_];
void* node_data = nullptr;
if (node.GetOpType() == "const") {
node_data = data_entry_[EntryID(tensor)]->data;
}
return MakeACLTensorFromJSONNode(node, scale, offset, node_data);
return MakeACLTensorFromJSONNode(node, scale, offset, node_data, apply_dim_correction,
increase_dim_unit);
}

/*!
Expand All @@ -196,19 +217,27 @@ class ACLRuntime : public JSONRuntimeBase {
* \param scale (optional) The scale of the tensor as an input.
* \param offset (optional) The offset of the tensor as an input.
* \param data (optional) Constant data of input node.
* \param apply_dim_correction (Optional) Flag to state whether apply dimension correction after
* setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but
* _num_dimensions should be 3 rather than 1.
* \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of
* dimensions of the shape.
* \return ACL Tensor.
*/
arm_compute::Tensor MakeACLTensorFromJSONNode(const JSONGraphNode& node,
JSONGraphNodeEntry* scale = nullptr,
JSONGraphNodeEntry* offset = nullptr,
void* data = nullptr) {
void* data = nullptr,
bool apply_dim_correction = true,
bool increase_dim_unit = true) {
const DLTensor* scale_data = nullptr;
const DLTensor* offset_data = nullptr;
if (scale && offset) {
scale_data = data_entry_[EntryID(*scale)];
offset_data = data_entry_[EntryID(*offset)];
}
return MakeACLTensor(node, data, scale_data, offset_data);
return MakeACLTensor(node, data, scale_data, offset_data, apply_dim_correction,
increase_dim_unit);
}

/*!
Expand Down Expand Up @@ -510,6 +539,30 @@ class ACLRuntime : public JSONRuntimeBase {
layer->function = f;
}

/*!
* \brief Create a Concatenate layer.
*
* \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.c
* \param node The JSON representation of the operator.
*/
void CreateConcatenateLayer(CachedLayer* layer, const JSONGraphNode& node) {
std::vector<std::string> axis = node.GetAttr<std::vector<std::string>>("axis");
std::vector<const arm_compute::ITensor*> inputs;
for (auto input : node.GetInputs()) {
layer->inputs.push_back(MakeACLTensorFromJSONEntry(input, nullptr, nullptr, false));
layer->json_inputid_to_layer_inputid[std::pair<uint32_t, uint32_t>(input.id_, input.index_)] =
layer->inputs.size() - 1;
}
for (size_t i = 0; i < layer->inputs.size(); i++) {
inputs.push_back(&layer->inputs[i]);
}
layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
int dimNum = layer->inputs[0].info()->num_dimensions();
auto function = std::make_shared<arm_compute::NEConcatenateLayer>();
function->configure(inputs, &layer->outputs[0], dimNum - std::stoi(axis[0]) - 1);
layer->function = function;
}

/*! \brief Allow ACL functions to request auxiliary memory from TVM. */
ACLAllocator allocator_;
/*!
Expand Down
11 changes: 7 additions & 4 deletions src/runtime/contrib/arm_compute_lib/acl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ void CheckACLError(const arm_compute::Status& status) {
}

arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data,
const DLTensor* scale, const DLTensor* offset) {
const DLTensor* scale, const DLTensor* offset,
bool apply_dim_correction, bool increase_dim_unit) {
arm_compute::Tensor tensor;
std::vector<int64_t> shape = tensor_rep.GetOpShape()[0];
DLDataType dtype = tensor_rep.GetOpDataType()[0];
arm_compute::TensorInfo info = MakeACLTensorInfo(shape, dtype, scale, offset);
arm_compute::TensorInfo info =
MakeACLTensorInfo(shape, dtype, scale, offset, apply_dim_correction, increase_dim_unit);
info.set_is_resizable(false);
tensor.allocator()->init(info);
if (data != nullptr) {
Expand All @@ -55,10 +57,11 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data,

arm_compute::TensorInfo MakeACLTensorInfo(const std::vector<int64_t>& shape,
const DLDataType& dtype, const DLTensor* scale,
const DLTensor* offset) {
const DLTensor* offset, bool apply_dim_correction,
bool increase_dim_unit) {
arm_compute::TensorShape acl_shape;
for (unsigned int i = shape.size(); i > 0; --i) {
acl_shape.set(shape.size() - i, shape[i - 1]);
acl_shape.set(shape.size() - i, shape[i - 1], apply_dim_correction, increase_dim_unit);
}
arm_compute::DataType acl_dtype = MakeACLDataType(dtype);
arm_compute::TensorInfo info(acl_shape, 1, acl_dtype, arm_compute::DataLayout::NHWC);
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/contrib/arm_compute_lib/acl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ void CheckACLError(const arm_compute::Status& status);
* \return arm_compute::Tensor.
*/
arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data = nullptr,
const DLTensor* scale = nullptr,
const DLTensor* offset = nullptr);
const DLTensor* scale = nullptr, const DLTensor* offset = nullptr,
bool apply_dim_correction = true, bool increase_dim_unit = true);

/*!
* \brief Make an acl tensor info object from JSON tensor
Expand All @@ -78,7 +78,9 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data =
*/
arm_compute::TensorInfo MakeACLTensorInfo(const std::vector<int64_t>& shape,
const DLDataType& dtype, const DLTensor* scale = nullptr,
const DLTensor* offset = nullptr);
const DLTensor* offset = nullptr,
bool apply_dim_correction = true,
bool increase_dim_unit = true);

/*!
* \brief Create a memory manager for use with a layer that
Expand Down
1 change: 1 addition & 0 deletions src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class JSONRuntimeBase : public ModuleNode {
for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
input_var_eid_.push_back(EntryID(nid, j));
}
nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size());
} else {
ICHECK_EQ(nodes_[nid].op_type_, "const");
auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
Expand Down
126 changes: 126 additions & 0 deletions tests/python/contrib/test_arm_compute_lib/test_concatenate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Arm Compute Library integration space_to_batch_nd tests."""

import numpy as np

import tvm
from tvm import relay
from tvm import testing

from test_arm_compute_lib.infrastructure import (
skip_runtime_test,
skip_codegen_test,
build_and_run,
verify,
verify_codegen,
)
from test_arm_compute_lib.infrastructure import Device


def _get_model(input_shape_a, input_shape_b, input_shape_c, axis, dtype, var_names):
"""Return a model and any parameters it may have."""
a = relay.var(next(var_names), shape=input_shape_a, dtype=dtype)
b = relay.var(next(var_names), shape=input_shape_b, dtype=dtype)
c = relay.var(next(var_names), shape=input_shape_c, dtype=dtype)
out = relay.concatenate([a, b, c], axis)
return out


def _get_expected_codegen(input_shape_a, input_shape_b, input_shape_c, axis, dtype):
node = {
"op": "kernel",
"name": "concatenate",
"inputs": [
[0, 0, 0],
[0, 1, 0],
[0, 2, 0],
],
"attrs": {
"num_outputs": "1",
"num_inputs": "3",
"dtype": [[dtype]],
"axis": [[str(axis)]],
"shape": [[[3, 234, 234, 256]]],
},
}

input = {
"op": "input",
"name": "",
"attrs": {
"shape": [[input_shape_a, input_shape_b, input_shape_c]],
"dtype": [[dtype, dtype, dtype]],
},
}

return [input, node]


def test_concatenate():
Device.load("test_config.json")

if skip_runtime_test():
return

device = Device()
np.random.seed(0)

for input_shape_a, input_shape_b, input_shape_c, axis in [
([1, 234, 234, 256], [1, 234, 234, 256], [1, 234, 234, 256], 0),
]:
dtype = "int32"
outputs = []
inputs = {
"a": tvm.nd.array(np.random.randn(*input_shape_a).astype(dtype)),
"b": tvm.nd.array(np.random.randn(*input_shape_b).astype(dtype)),
"c": tvm.nd.array(np.random.randn(*input_shape_c).astype(dtype)),
}
func = _get_model(
inputs["a"].shape, inputs["b"].shape, inputs["c"].shape, axis, dtype, iter(inputs)
)
for acl in [False, True]:
outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl)[0])

config = {
"input_shape_a": input_shape_a,
"input_shape_b": input_shape_b,
"input_shape_c": input_shape_c,
"axis": 0,
"dtype": dtype,
}
verify(outputs, atol=1e-7, rtol=1e-7, config=config)


def test_codegen_concatenate():
if skip_codegen_test():
return
shape_a = [1, 234, 234, 256]
shape_b = [1, 234, 234, 256]
shape_c = [1, 234, 234, 256]
axis = 0
inputs = {"a", "b", "c"}
for dtype in ["float32"]:
args = (shape_a, shape_b, shape_c, axis, dtype)
func = _get_model(*args, iter(inputs))
exp_codegen = _get_expected_codegen(*args)
verify_codegen(func, exp_codegen, 1)


if __name__ == "__main__":
test_concatenate()
test_codegen_concatenate()