diff --git a/CMakeLists.txt b/CMakeLists.txt
index 38dd59b9c906..66ea6a07da85 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -125,6 +125,8 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF)
 tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "Build with Arm Compute Library graph executor" OFF)
 tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF)
 tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF)
+tvm_option(USE_NNAPI_CODEGEN "Build with NNAPI Codegen support" OFF)
+tvm_option(USE_NNAPI_RUNTIME "Build with NNAPI runtime" OFF)
 tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions, STATIC, DYNAMIC, or OFF" OFF)
 tvm_option(USE_VITIS_AI "Build with VITIS-AI Codegen support" OFF)
 tvm_option(SUMMARIZE "Print CMake option summary after configuring" OFF)
@@ -602,6 +604,7 @@ include(cmake/modules/contrib/BNNS.cmake)
 include(cmake/modules/contrib/ONNX.cmake)
 include(cmake/modules/contrib/ArmComputeLib.cmake)
 include(cmake/modules/contrib/TensorRT.cmake)
+include(cmake/modules/contrib/NNAPI.cmake)
 include(cmake/modules/contrib/VitisAI.cmake)
 include(cmake/modules/contrib/Verilator.cmake)
 include(cmake/modules/contrib/UMA.cmake)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index a2b51bb33195..ee6561dffce8 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -144,6 +144,8 @@ function(add_lib_info src_file)
     TVM_INFO_USE_MSC="${USE_MSC}"
     TVM_INFO_USE_CCACHE="${USE_CCACHE}"
     TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}"
+    TVM_INFO_USE_NNAPI_CODEGEN="${USE_NNAPI_CODEGEN}"
+    TVM_INFO_USE_NNAPI_RUNTIME="${USE_NNAPI_RUNTIME}"
     TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
   )

diff --git a/cmake/modules/contrib/NNAPI.cmake b/cmake/modules/contrib/NNAPI.cmake
new file mode 100644
index 000000000000..23eb6dd11eda
--- /dev/null
+++ b/cmake/modules/contrib/NNAPI.cmake
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# NNAPI Codegen: adds the Relax NNAPI codegen sources to COMPILER_SRCS (no libraries linked, no definitions added).
+if(USE_NNAPI_CODEGEN)
+  message(STATUS "Build with NNAPI codegen")
+
+  tvm_file_glob(GLOB COMPILER_NNAPI_SRCS src/relax/backend/contrib/nnapi/*.cc)
+  tvm_file_glob(GLOB RUNTIME_NNAPI_SRCS src/runtime/contrib/nnapi/*.cc)
+  list(APPEND COMPILER_SRCS ${COMPILER_NNAPI_SRCS})
+  if(NOT USE_NNAPI_RUNTIME)  # runtime sources go into the compiler build only when the runtime branch below will not add them
+    list(APPEND COMPILER_SRCS ${RUNTIME_NNAPI_SRCS})
+  endif()
+endif()
+
+# NNAPI Runtime: adds the runtime sources to RUNTIME_SRCS, links `neuralnetworks` and `log`, and defines TVM_GRAPH_EXECUTOR_NNAPI.
+if(USE_NNAPI_RUNTIME)
+  message(STATUS "Build with NNAPI runtime")
+
+  tvm_file_glob(GLOB RUNTIME_NNAPI_SRCS src/runtime/contrib/nnapi/*.cc)
+  list(APPEND RUNTIME_SRCS ${RUNTIME_NNAPI_SRCS})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS neuralnetworks log)
+
+  add_definitions(-DTVM_GRAPH_EXECUTOR_NNAPI)  # NOTE(review): nnapi_builder.cc/.h are wrapped in #ifdef on this macro
+endif()
diff --git a/python/tvm/relax/backend/contrib/nnapi.py b/python/tvm/relax/backend/contrib/nnapi.py
new file mode 100644
index 000000000000..6e428b60d584
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/nnapi.py
@@ -0,0 +1,324 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for NNAPI backend"""
+from typing import (
+    Mapping,
+    Optional,
+    Tuple,
+    List,
+)
+from tvm.ir import IRModule
+from tvm.relax.transform import FuseOpsByPattern, MergeCompositeFunctions
+from tvm.relax.dpl.pattern import (
+    DFPattern,
+    wildcard,
+    is_op,
+)
+
+from ..pattern_registry import get_patterns_with_prefix, register_patterns
+
+
+def elementwise_binary_patterns() -> List[Tuple[str, DFPattern, Mapping[str, DFPattern]]]:
+    """
+    Returns a list of tuples representing elementwise binary operation patterns mapped
+    between NNAPI and Relax frameworks.
+
+    Each tuple is ``(pattern_name, pattern, annotations)``; the annotation mapping is
+    always empty here.
+    """
+
+    def _elementwise_binary_pattern(
+        pattern_name: str,
+        op_name: str,
+    ) -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
+        # Both operands are unconstrained wildcards.
+        input0 = wildcard()
+        input1 = wildcard()
+
+        pattern = is_op(op_name)(input0, input1)
+
+        return (pattern_name, pattern, {})
+
+    return [
+        _elementwise_binary_pattern("nnapi.add", "relax.add"),
+        _elementwise_binary_pattern("nnapi.mul", "relax.multiply"),
+        _elementwise_binary_pattern("nnapi.div", "relax.divide"),
+        _elementwise_binary_pattern("nnapi.sub", "relax.subtract"),
+        _elementwise_binary_pattern("nnapi.pow", "relax.power"),
+        _elementwise_binary_pattern("nnapi.equal", "relax.equal"),
+        _elementwise_binary_pattern("nnapi.greater", "relax.greater"),
+        _elementwise_binary_pattern("nnapi.greater_equal", "relax.greater_equal"),
+        _elementwise_binary_pattern("nnapi.less", "relax.less"),
+        _elementwise_binary_pattern("nnapi.less_equal", "relax.less_equal"),
+        _elementwise_binary_pattern("nnapi.not_equal", "relax.not_equal"),
+        _elementwise_binary_pattern("nnapi.maximum", "relax.maximum"),
+        _elementwise_binary_pattern("nnapi.minimum", "relax.minimum"),
+    ]
+
+
+def unary_patterns() -> List[Tuple[str, DFPattern, Mapping[str, DFPattern]]]:
+    """
+    Returns a list of tuples representing unary operation patterns mapped
+    between NNAPI and Relax frameworks.
+
+    Each tuple is ``(pattern_name, pattern, annotations)``; the annotation mapping is
+    always empty here.
+    """
+
+    def _unary_pattern(
+        pattern_name: str, op_name: str
+    ) -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
+        input0 = wildcard()
+        pattern = is_op(op_name)(input0)
+        return (pattern_name, pattern, {})
+
+    return [
+        _unary_pattern("nnapi.floor", "relax.floor"),
+        _unary_pattern("nnapi.relu", "relax.nn.relu"),
+        _unary_pattern("nnapi.logistic", "relax.sigmoid"),
+        _unary_pattern("nnapi.softmax", "relax.nn.softmax"),
+        _unary_pattern("nnapi.tanh", "relax.tanh"),
+        _unary_pattern("nnapi.abs", "relax.abs"),
+        _unary_pattern("nnapi.exp", "relax.exp"),
+        _unary_pattern("nnapi.log", "relax.log"),
+        _unary_pattern("nnapi.neg", "relax.negative"),
+        # NOTE(review): "nnapi.cast" is registered a second time by astype_pattern()
+        # with float16/float32 dtype constraints; this entry has no dtype constraint.
+        # Confirm that keeping both registrations is intended.
+        _unary_pattern("nnapi.cast", "relax.astype"),
+        _unary_pattern("nnapi.sqrt", "relax.sqrt"),
+        _unary_pattern("nnapi.rsqrt", "relax.rsqrt"),
+    ]
+
+
+def matmul_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
+    """
+    Returns a tuple representing matmul operation patterns mapped
+    between NNAPI and Relax frameworks.
+
+    ``relax.matmul`` on two unconstrained inputs is named ``nnapi.batch_matmul``.
+    """
+    input0 = wildcard()
+    input1 = wildcard()
+    pattern = is_op("relax.matmul")(input0, input1)
+    return ("nnapi.batch_matmul", pattern, {})
+
+
+def permute_dims_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
+    """
+    Returns a tuple representing permute operation patterns mapped
+    between NNAPI and Relax frameworks.
+
+    ``relax.permute_dims`` on an unconstrained input is named ``nnapi.transpose``.
+    """
+    input0 = wildcard()
+    pattern = is_op("relax.permute_dims")(input0)
+    return ("nnapi.transpose", pattern, {})
+
+
+def astype_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
+    """
+    Returns a tuple representing astype operation patterns mapped
+    between NNAPI and Relax frameworks.
+
+    Only casts whose input and result dtypes are both float16 or float32 match.
+    """
+    input0 = wildcard().has_dtype("float16") | wildcard().has_dtype("float32")
+    pattern = is_op("relax.astype")(input0).has_dtype("float16") | is_op("relax.astype")(
+        input0
+    ).has_dtype("float32")
+
+    return ("nnapi.cast", pattern, {})
+
+
+def mean_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
+    """
+    Returns a tuple representing mean operation patterns mapped
+    between NNAPI and Relax frameworks.
+    """
+    input0 = wildcard()
+    pattern = is_op("relax.mean")(input0)
+
+    return ("nnapi.mean", pattern, {})
+
+
+def conv2d_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
+    """
+    Returns a tuple representing conv2d operation patterns mapped
+    between NNAPI and Relax frameworks.
+
+    The pattern is ``relax.add(relax.nn.conv2d(x, w), b)``; a ``conv2d`` that is not
+    followed by an ``add`` does not match this pattern.
+    """
+    input0 = wildcard()
+    input1 = wildcard()
+    input2 = wildcard()
+    conv = is_op("relax.nn.conv2d")(input0, input1)
+    pattern = is_op("relax.add")(conv, input2)
+    return ("nnapi.conv2d", pattern, {})
+
+
+def max_pool2d_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
+    """
+    Returns a tuple representing max_pool2d operation patterns mapped
+    between NNAPI and Relax frameworks.
+    """
+    input0 = wildcard()
+    pattern = is_op("relax.nn.max_pool2d")(input0)
+    return ("nnapi.max_pool_2d", pattern, {})
+
+
+register_patterns(
+    [
+        *elementwise_binary_patterns(),
+        *unary_patterns(),
+        matmul_pattern(),
+        permute_dims_pattern(),
+        astype_pattern(),
+        mean_pattern(),
+        conv2d_pattern(),
+        max_pool2d_pattern(),
+    ]
+)
+
+
+# Minimum NNAPI feature level per pattern name. Built once at import time instead of
+# on every min_feature_level() call.
+_MIN_FEATURE_LEVELS: Mapping[str, int] = {
+    "nnapi.add": 1,
+    "nnapi.average_pool_2d": 1,
+    "nnapi.concatenation": 1,
+    "nnapi.conv2d": 1,
+    "nnapi.depthwise_conv_2d": 1,
+    "nnapi.depth_to_space": 1,
+    "nnapi.dequantize": 1,
+    "nnapi.embedding_lookup": 1,
+    "nnapi.floor": 1,
+    "nnapi.fully_connected": 1,
+    "nnapi.hashtable_lookup": 1,
+    "nnapi.l2_normalization": 1,
+    "nnapi.l2_pool_2d": 1,
+    "nnapi.local_response_normalization": 1,
+    "nnapi.logistic": 1,
+    "nnapi.lsh_projection": 1,
+    "nnapi.lstm": 1,
+    "nnapi.max_pool_2d": 1,
+    "nnapi.mul": 1,
+    "nnapi.relu": 1,
+    "nnapi.relu1": 1,
+    "nnapi.relu6": 1,
+    "nnapi.reshape": 1,
+    "nnapi.resize_bilinear": 1,
+    "nnapi.rnn": 1,
+    "nnapi.softmax": 1,
+    "nnapi.space_to_depth": 1,
+    "nnapi.svdf": 1,
+    "nnapi.tanh": 1,
+    "nnapi.batch_to_space_nd": 2,
+    "nnapi.div": 2,
+    "nnapi.mean": 2,
+    "nnapi.pad": 2,
+    "nnapi.space_to_batch_nd": 2,
+    "nnapi.squeeze": 2,
+    "nnapi.strided_slice": 2,
+    "nnapi.sub": 2,
+    "nnapi.transpose": 2,
+    "nnapi.abs": 3,
+    "nnapi.argmax": 3,
+    "nnapi.argmin": 3,
+    "nnapi.axis_aligned_bbox_transform": 3,
+    "nnapi.bidirectional_sequence_lstm": 3,
+    "nnapi.bidirectional_sequence_rnn": 3,
+    "nnapi.box_with_nms_limit": 3,
+    "nnapi.cast": 3,
+    "nnapi.channel_shuffle": 3,
+    "nnapi.detection_postprocessing": 3,
+    "nnapi.equal": 3,
+    "nnapi.exp": 3,
+    "nnapi.expand_dims": 3,
+    "nnapi.gather": 3,
+    "nnapi.generate_proposals": 3,
+    "nnapi.greater": 3,
+    "nnapi.greater_equal": 3,
+    "nnapi.grouped_conv_2d": 3,
+    "nnapi.heatmap_max_keypoint": 3,
+    "nnapi.instance_normalization": 3,
+    "nnapi.less": 3,
+    "nnapi.less_equal": 3,
+    "nnapi.log": 3,
+    "nnapi.logical_and": 3,
+    "nnapi.logical_not": 3,
+    "nnapi.logical_or": 3,
+    "nnapi.log_softmax": 3,
+    "nnapi.maximum": 3,
+    "nnapi.minimum": 3,
+    "nnapi.neg": 3,
+    "nnapi.not_equal": 3,
+    "nnapi.pad_v2": 3,
+    "nnapi.pow": 3,
+    "nnapi.prelu": 3,
+    "nnapi.quantize": 3,
+    "nnapi.quantized_16bit_lstm": 3,
+    "nnapi.random_multinomial": 3,
+    "nnapi.reduce_all": 3,
+    "nnapi.reduce_any": 3,
+    "nnapi.reduce_max": 3,
+    "nnapi.reduce_min": 3,
+    "nnapi.reduce_prod": 3,
+    "nnapi.reduce_sum": 3,
+    "nnapi.roi_align": 3,
+    "nnapi.roi_pooling": 3,
+    "nnapi.rsqrt": 3,
+    "nnapi.select": 3,
+    "nnapi.sin": 3,
+    "nnapi.slice": 3,
+    "nnapi.split": 3,
+    "nnapi.sqrt": 3,
+    "nnapi.tile": 3,
+    "nnapi.topk_v2": 3,
+    "nnapi.transpose_conv_2d": 3,
+    "nnapi.unidirectional_sequence_lstm": 3,
+    "nnapi.unidirectional_sequence_rnn": 3,
+    "nnapi.resize_nearest_neighbor": 3,
+    "nnapi.quantized_lstm": 4,
+    "nnapi.if": 4,
+    "nnapi.while": 4,
+    "nnapi.elu": 4,
+    "nnapi.hard_swish": 4,
+    "nnapi.fill": 4,
+    "nnapi.rank": 4,
+    "nnapi.batch_matmul": 6,
+    "nnapi.pack": 6,
+    "nnapi.mirror_pad": 7,
+    "nnapi.reverse": 7,
+}
+
+
+def min_feature_level(pattern_name: str) -> int:
+    """
+    Returns the minimum feature level required to support a given NNAPI operation pattern.
+
+    Args:
+        pattern_name (str): The name of the NNAPI operation pattern
+            (e.g., "nnapi.add", "nnapi.conv2d").
+
+    Returns:
+        int: The minimum feature level for the specified pattern, or 1 if the pattern is not found.
+    """
+    # The documented contract is a fallback of 1 for unknown names; a plain subscript
+    # raised KeyError instead.
+    return _MIN_FEATURE_LEVELS.get(pattern_name, 1)
+
+
+def partition_for_nnapi(mod: IRModule, feature_level: Optional[int] = None) -> IRModule:
+    """Partition the graph greedily offloading supported operators to NNAPI.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The module to run passes on.
+    feature_level : Optional[int]
+        The maximum NNAPI feature level. When given, patterns whose minimum feature
+        level exceeds it are dropped before fusing; when None, every registered
+        "nnapi" pattern is used.
+
+    Returns
+    -------
+    mod : tvm.ir.IRModule
+        Annotated and partitioned module.
+    """
+    patterns = get_patterns_with_prefix("nnapi")
+    if feature_level is not None:
+        patterns = [pat for pat in patterns if feature_level >= min_feature_level(pat.name)]
+    mod = FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=False)(mod)
+    mod = MergeCompositeFunctions()(mod)
+    return mod
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 8227530f7ab7..8b919d2c9dca 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -980,6 +980,12 @@ def _multi_gpu_exists():
     target_kind_enabled="opencl",
 )

+# Mark a test as requiring NNAPI support in build.
+requires_nnapi = Feature(
+    "NNAPI",
+    "NNAPI",
+    cmake_flag="USE_NNAPI_CODEGEN",
+)

 # Mark a test as requiring microTVM to run
 requires_micro = Feature("micro", "MicroTVM", cmake_flag="USE_MICRO")
diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc
new file mode 100644
index 000000000000..ef74cca70ee8
--- /dev/null
+++ b/src/relax/backend/contrib/nnapi/codegen.cc
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include "../../../transform/utils.h"
+#include "../codegen_json/codegen_json.h"
+#include "tvm/relax/attrs/manipulate.h"
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+using JSONSerializer = backend::contrib::JSONSerializer;
+using JSONGraphNode = backend::contrib::JSONGraphNode;
+using JSONGraphNodeEntry = backend::contrib::JSONGraphNodeEntry;
+using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
+using NodeEntries = backend::contrib::NodeEntries;
+
+class NNAPIJSONSerializer;
+
+/*!
+ * \brief Visitor that walks the body of a composite function and records operator
+ *  attributes as string lists on a scratch JSON node (`node_`).
+ *
+ *  NNAPIJSONSerializer copies the collected attributes onto the real graph node
+ *  through CaptureAttrs.
+ */
+class CollectFromCompositeFunctionBody : public ExprVisitor {
+ public:
+  explicit CollectFromCompositeFunctionBody(NNAPIJSONSerializer* serializer)
+      : serializer_(serializer), node_(std::make_shared()) {}
+
+  /*! \brief Dispatch on the callee operator name, then continue the default traversal. */
+  void VisitExpr_(const CallNode* call_node) override;
+
+  /*! \brief Record "axes" for relax.permute_dims; nothing is recorded when axes is unset. */
+  void SetPermuteDimsAttribute(const CallNode* call_node) {
+    const auto* permute_dims_attr = call_node->attrs.as();
+    ICHECK(permute_dims_attr);
+    if (permute_dims_attr->axes) {
+      std::vector axes;
+      for (auto axis : permute_dims_attr->axes.value()) {
+        axes.push_back(std::to_string(axis.IntValue()));
+      }
+
+      std::vector axes_attr;
+      axes_attr.emplace_back(axes);
+      node_->SetAttr("axes", axes_attr);
+    }
+  }
+
+  /*! \brief Record the target dtype of relax.astype as "astype_dtype". */
+  void SetAstypeAttribute(const CallNode* call_node) {
+    const auto* astype_attrs = call_node->attrs.as();
+    ICHECK(astype_attrs);
+
+    std::vector dtype_attr;
+    auto dtype_str = runtime::DLDataType2String(astype_attrs->dtype);
+    dtype_attr.emplace_back(std::vector{dtype_str});
+    node_->SetAttr("astype_dtype", dtype_attr);
+  }
+
+  /*!
+   * \brief Record "axis" and "keepdims" for relax.mean.
+   *  Fails the ICHECK when the axis attribute is not defined.
+   */
+  void SetMeanAttribute(const CallNode* call_node) {
+    const auto* mean_attrs = call_node->attrs.as();
+    ICHECK(mean_attrs);
+    ICHECK(mean_attrs->axis.defined());
+
+    {
+      std::vector axis;
+      for (auto dim : mean_attrs->axis.value()) {
+        axis.push_back(std::to_string(dim->value));
+      }
+
+      std::vector axis_attr;
+      axis_attr.emplace_back(axis);
+      node_->SetAttr("axis", axis_attr);
+    }
+
+    {
+      // keepdims is serialized as the string "1" or "0".
+      const std::vector keepdims{mean_attrs->keepdims ? "1" : "0"};
+      std::vector keepdims_attr;
+      keepdims_attr.emplace_back(keepdims);
+      node_->SetAttr("keepdims", keepdims_attr);
+    }
+  }
+
+  /*!
+   * \brief Record "strides", "padding" and "group" for relax.nn.conv2d.
+   *  Empty strides default to {"1", "1"}.
+   */
+  void SetConv2dAttribute(const CallNode* call_node) {
+    const auto* conv2d_attr = call_node->attrs.as();
+    ICHECK(conv2d_attr) << "didn't catch attributes";
+
+    std::vector strides;
+    if (!conv2d_attr->strides.empty()) {
+      for (auto stride : conv2d_attr->strides) {
+        const auto* stride_val = stride.as();
+        ICHECK(stride_val) << "conversion failed";
+
+        strides.push_back(std::to_string(stride_val->value));
+      }
+    } else {
+      strides = {"1", "1"};
+    }
+
+    std::vector padding;
+    for (auto pad : conv2d_attr->padding) {
+      const auto* padding_val = pad.as();
+      // Same guard as for strides: a failed downcast must not be dereferenced.
+      ICHECK(padding_val) << "conversion failed";
+
+      padding.push_back(std::to_string(padding_val->value));
+    }
+
+    std::vector groups;
+    const int group_val = conv2d_attr->groups;
+    groups.push_back(std::to_string(group_val));
+
+    std::vector strides_attr;
+    strides_attr.emplace_back(strides);
+    node_->SetAttr("strides", strides_attr);
+
+    std::vector padding_attr;
+    padding_attr.emplace_back(padding);
+    node_->SetAttr("padding", padding_attr);
+
+    std::vector group_attr;
+    group_attr.emplace_back(groups);
+    node_->SetAttr("group", group_attr);
+  }
+
+  /*!
+   * \brief Record "strides", "padding" and "pool_size" for relax.nn.max_pool2d.
+   *  Empty strides default to {"1", "1"}.
+   */
+  void SetMaxPool2dAttribute(const CallNode* call_node) {
+    const auto* max_pool_2d_attr = call_node->attrs.as();
+    ICHECK(max_pool_2d_attr) << "didn't catch attributes";
+
+    std::vector strides;
+    if (!max_pool_2d_attr->strides.empty()) {
+      for (auto stride : max_pool_2d_attr->strides) {
+        const auto* stride_val = stride.as();
+        ICHECK(stride_val) << "conversion failed";
+
+        strides.push_back(std::to_string(stride_val->value));
+      }
+    } else {
+      strides.push_back("1");
+      strides.push_back("1");
+    }
+
+    std::vector padding;
+    for (auto pad : max_pool_2d_attr->padding) {
+      const auto* padding_val = pad.as();
+      // Same guard as for strides: a failed downcast must not be dereferenced.
+      ICHECK(padding_val) << "conversion failed";
+
+      padding.push_back(std::to_string(padding_val->value));
+    }
+
+    std::vector pool_size;
+    for (auto size : max_pool_2d_attr->pool_size) {
+      const auto* pooling_val = size.as();
+      ICHECK(pooling_val) << "conversion failed";
+
+      pool_size.push_back(std::to_string(pooling_val->value));
+    }
+
+    std::vector strides_attr;
+    strides_attr.emplace_back(strides);
+    node_->SetAttr("strides", strides_attr);
+
+    std::vector padding_attr;
+    padding_attr.emplace_back(padding);
+    node_->SetAttr("padding", padding_attr);
+
+    std::vector pooling_attr;
+    pooling_attr.emplace_back(pool_size);
+    node_->SetAttr("pool_size", pooling_attr);
+  }
+
+  NNAPIJSONSerializer* serializer_;
+  /*! \brief Scratch node that only carries the collected attributes. */
+  JSONGraphObjectPtr node_;
+};
+
+/*!
+ * \brief JSON serializer that emits one "kernel" node per call to a composite function.
+ */
+class NNAPIJSONSerializer : public JSONSerializer {
+ public:
+  explicit NNAPIJSONSerializer(Map constant_names, Map bindings)
+      : JSONSerializer(constant_names), bindings_(bindings) {}
+  using JSONSerializer::VisitExpr_;
+
+  /*!
+   * \brief Serialize a call whose callee is a variable bound to a composite function.
+   *  The node is named after the function's kComposite attribute and has one output.
+   */
+  std::vector VisitExpr_(const CallNode* call_node) final {
+    const auto* fn_var = call_node->op.as();
+    ICHECK(fn_var);
+    const auto fn = Downcast(bindings_[GetRef(fn_var)]);
+    ICHECK(fn.defined()) << "Expects the callee to be a function.";
+
+    auto composite_opt = fn->GetAttr(attr::kComposite);
+    ICHECK(composite_opt.defined()) << "Only composite functions are supported.";
+
+    std::string composite_name = composite_opt.value();
+
+    // Gather operator attributes from the composite body before building the node.
+    CollectFromCompositeFunctionBody collector(this);
+    collector.VisitExpr(fn->body);
+
+    NodeEntries inputs;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+
+    auto node = std::make_shared(composite_name, /* name_ */
+                                 "kernel",       /* op_type_ */
+                                 inputs, 1 /* num_outputs_ */);
+    node->CaptureAttrs(*collector.node_);
+
+    VLOG(1) << "Adding node " << composite_name << " with " << node->GetInputs().size()
+            << " inputs";
+    return AddNode(node, GetRef(call_node));
+  }
+
+ private:
+  /*! \brief Variable-to-value bindings used to resolve the callee of each call. */
+  Map bindings_;
+};
+
+void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
+  const auto* op_node = call_node->op.as();
+  ICHECK(op_node != nullptr);
+  std::string name = op_node->name;
+  // Operators not listed here contribute no attributes.
+  if (name == "relax.permute_dims") {
+    SetPermuteDimsAttribute(call_node);
+  } else if (name == "relax.astype") {
+    SetAstypeAttribute(call_node);
+  } else if (name == "relax.mean") {
+    SetMeanAttribute(call_node);
+  } else if (name == "relax.nn.conv2d") {
+    SetConv2dAttribute(call_node);
+  } else if (name == "relax.nn.max_pool2d") {
+    SetMaxPool2dAttribute(call_node);
+  }
+  ExprVisitor::VisitExpr_(call_node);
+}
+
+/*!
+ * \brief Serialize each function to JSON and wrap it in an NNAPI runtime module.
+ * \param functions The functions to compile.
+ * \param constant_names Constant names passed to the serializer.
+ * \return One module per input function, in input order.
+ */
+Array NNAPICompiler(Array functions, Map /*unused*/,
+                    Map constant_names) {
+  VLOG(1) << "NNAPI Compiler";
+
+  Array compiled_functions;
+  for (const auto& func : functions) {
+    NNAPIJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
+    serializer.serialize(func);
+    auto graph_json = serializer.GetJSON();
+    // Named differently from the `constant_names` parameter, which the original
+    // local shadowed.
+    auto serialized_constant_names = serializer.GetConstantNames();
+    const auto* pf = runtime::Registry::Get("runtime.nnapi_runtime_create");
+    ICHECK(pf != nullptr) << "Cannot find NNAPI runtime module create function.";
+    auto func_name = GetExtSymbol(func);
+    compiled_functions.push_back((*pf)(func_name, graph_json, serialized_constant_names));
+  }
+
+  return compiled_functions;
+}
+
+TVM_REGISTER_GLOBAL("relax.ext.nnapi").set_body_typed(NNAPICompiler);
+
+}  // namespace contrib
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/runtime/contrib/nnapi/nnapi_builder.cc b/src/runtime/contrib/nnapi/nnapi_builder.cc
new file mode 100644
index 000000000000..d43f00661de9
--- /dev/null
+++ b/src/runtime/contrib/nnapi/nnapi_builder.cc
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TVM_GRAPH_EXECUTOR_NNAPI
+
+#include "nnapi_builder.h"
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "../json/json_runtime.h"
+#include "nnapi_ops.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+// Owns the dimension storage that ty_.dimensions points into; ty_.dimensions is
+// nullptr when there are no dimensions.
+WrappedANeuralNetworksOperandType::WrappedANeuralNetworksOperandType(
+    int32_t tensor_type, std::vector dimensions, float scale, int32_t zero_point)
+    : dimensions_(dimensions) {
+  ty_.type = tensor_type;
+  if (dimensions_.empty()) {
+    ty_.dimensions = nullptr;
+  } else {
+    ty_.dimensions = dimensions_.data();
+  }
+  ty_.dimensionCount = dimensions_.size();
+  ty_.scale = scale;
+  ty_.zeroPoint = zero_point;
+}
+
+// Copying must re-seat ty_.dimensions so it points at this object's own vector
+// rather than at the source object's storage.
+WrappedANeuralNetworksOperandType::WrappedANeuralNetworksOperandType(
+    const WrappedANeuralNetworksOperandType& other)
+    : dimensions_(other.dimensions_), ty_(other.ty_) {
+  if (dimensions_.empty()) {
+    ty_.dimensions = nullptr;
+  } else {
+    ty_.dimensions = dimensions_.data();
+  }
+}
+
+// Copy-assign member by member. The previous copy-and-std::swap form recursed
+// without bound: the class declares no move operations, so std::swap(*this, temp)
+// falls back to this very copy-assignment operator.
+WrappedANeuralNetworksOperandType& WrappedANeuralNetworksOperandType::operator=(
+    const WrappedANeuralNetworksOperandType& other) {
+  if (this != &other) {
+    dimensions_ = other.dimensions_;
+    ty_ = other.ty_;
+    // Re-seat the pointer at our own storage, as the copy constructor does.
+    if (dimensions_.empty()) {
+      ty_.dimensions = nullptr;
+    } else {
+      ty_.dimensions = dimensions_.data();
+    }
+  }
+  return *this;
+}
+
+const ANeuralNetworksOperandType* WrappedANeuralNetworksOperandType::Get() const { return &ty_; }
+
+// Operand described by a DLTensor. A zero-dimensional tensor is given the shape {1}.
+NNAPIOperand::NNAPIOperand(uint32_t index, const DLTensor* tensor)
+    : index_(index), scalar_(false), dimensions_(tensor->shape, tensor->shape + tensor->ndim) {
+  if (dimensions_.size() == 0) {
+    dimensions_.push_back(1);
+  }
+
+  tensor_type_ = TensorTypeFromDLDataType(tensor->dtype);
+  scale_ = 0.0;
+  zero_point_ = 0;
+}
+
+// Operand described by an explicit shape and dtype. An empty shape becomes {1}.
+NNAPIOperand::NNAPIOperand(uint32_t index, const int64_t* shape, int ndim, DLDataType dtype)
+    : index_(index), scalar_(false), dimensions_(shape, shape + ndim) {
+  if (dimensions_.size() == 0) {
+    dimensions_.push_back(1);
+  }
+
+  tensor_type_ = TensorTypeFromDLDataType(dtype);
+  scale_ = 0.0;
+  zero_point_ = 0;
+}
+
+// Operand described directly by an NNAPI operand type. Empty dimensions become {1}.
+NNAPIOperand::NNAPIOperand(uint32_t index, int32_t tensor_type, std::vector dimensions,
+                           float scale, int32_t zero_point)
+    : index_(index),
+      scalar_(false),
+      tensor_type_(tensor_type),
+      dimensions_(dimensions),
+      scale_(scale),
+      zero_point_(zero_point) {
+  if (dimensions_.size() == 0) {
+    dimensions_.push_back(1);
+  }
+}
+
+// Build a scalar operand: the dimensions are cleared after construction, so the
+// `dimensions` argument does not survive into the result.
+NNAPIOperand NNAPIOperand::Scalar(uint32_t index, int32_t tensor_type,
+                                  std::vector dimensions, float scale,
+                                  int32_t zero_point) {
+  NNAPIOperand operand(index, tensor_type, dimensions, scale, zero_point);
+  operand.dimensions_.clear();
+  operand.scalar_ = true;
+  return operand;
+}
+
+void NNAPIOperand::SetDimensions(std::vector dimensions) { dimensions_ = dimensions; }
+
+// Produce the NNAPI operand type from a fresh copy of the stored dimensions.
+WrappedANeuralNetworksOperandType NNAPIOperand::GetOperandType() const {
+  std::vector dimensions(dimensions_.begin(), dimensions_.end());
+  return WrappedANeuralNetworksOperandType(tensor_type_, dimensions, scale_, zero_point_);
+}
+
+uint32_t NNAPIOperand::GetOperandIndex() const { return index_; }
+
+const std::vector& NNAPIOperand::GetDimensions() const { return dimensions_; }
+const float NNAPIOperand::GetScale() const { return scale_; }
+const int32_t NNAPIOperand::GetZeroPoint() const { return zero_point_; }
+
+int32_t NNAPIOperand::GetTensorType() const { return tensor_type_; }
+// A dimension equal to -1 marks the shape as dynamic.
+bool NNAPIOperand::IsDynamicShape() const {
+  return std::any_of(dimensions_.begin(), dimensions_.end(), [](int64_t dim) { return dim == -1; });
+}
+
+NNAPIModelBuilder::NNAPIModelBuilder() {
+  ICHECK_EQ(ANeuralNetworksModel_create(&model_), ANEURALNETWORKS_NO_ERROR);
+}
+
+NNAPIModelBuilder::~NNAPIModelBuilder() { ANeuralNetworksModel_free(model_); }
+
+// Add an operand typed after `tensor` and set its value from tensor.data.
+NNAPIOperand NNAPIModelBuilder::CreateOperandWithValue(const DLTensor& tensor) {
+  NNAPIOperand operand(next_operand_index_++, &tensor);
+  const size_t operand_data_size = GetDataSize(tensor);
+
+  ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()),
+            ANEURALNETWORKS_NO_ERROR);
+  ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), tensor.data,
+                                                 operand_data_size),
+            ANEURALNETWORKS_NO_ERROR);
+
+  return operand;
+}
+
+// Add an operand with an explicit NNAPI type and set its value from `buffer`.
+NNAPIOperand NNAPIModelBuilder::CreateOperandWithValue(int32_t tensor_type,
+                                                       std::vector dimensions, float scale,
+                                                       int32_t zero_point, const void* buffer,
+                                                       size_t size) {
+  NNAPIOperand operand(next_operand_index_++, tensor_type, dimensions, scale, zero_point);
+  ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()),
+            ANEURALNETWORKS_NO_ERROR);
+  ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), buffer, size),
+            ANEURALNETWORKS_NO_ERROR);
+  return operand;
+}
+
+// Add a scalar operand (no dimensions, scale 0, zero point 0) and set its value.
+NNAPIOperand NNAPIModelBuilder::CreateScalarOperandWithValue(OperandCode operand_code,
+                                                             const void* buffer, size_t size) {
+  NNAPIOperand operand = NNAPIOperand::Scalar(next_operand_index_++, operand_code, {}, 0.0f, 0);
+
+  ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()),
+            ANEURALNETWORKS_NO_ERROR);
+  ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), buffer, size),
+            ANEURALNETWORKS_NO_ERROR);
+  return operand;
+}
+
+// The CreateOperand overloads add an operand to the model without setting a value.
+NNAPIOperand NNAPIModelBuilder::CreateOperand(const DLTensor& tensor) {
+  NNAPIOperand operand(next_operand_index_++, tensor.shape, tensor.ndim, tensor.dtype);
+  ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()),
+            ANEURALNETWORKS_NO_ERROR);
+  return operand;
+}
+
+NNAPIOperand NNAPIModelBuilder::CreateOperand(const int64_t* shape, int ndim, DLDataType dtype) {
+  NNAPIOperand operand(next_operand_index_++, shape, ndim, dtype);
+  ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()),
+            ANEURALNETWORKS_NO_ERROR);
+  return operand;
+}
+
+NNAPIOperand NNAPIModelBuilder::CreateOperand(int32_t tensor_type, std::vector dimensions,
+                                              float scale, int32_t zero_point) {
+  NNAPIOperand operand(next_operand_index_++, tensor_type, dimensions, scale, zero_point);
+  ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()),
+            ANEURALNETWORKS_NO_ERROR);
+  return operand;
+}
+
+// Add one operation whose inputs and outputs are given as operand indices.
+// Parameter names now match the spelling used in the header declaration.
+void NNAPIModelBuilder::AddOperation(ANeuralNetworksOperationType operation,
+                                     const std::vector input_indices,
+                                     const std::vector output_indices) {
+  ICHECK_EQ(ANeuralNetworksModel_addOperation(model_, operation, input_indices.size(),
+                                              input_indices.data(), output_indices.size(),
+                                              output_indices.data()),
+            ANEURALNETWORKS_NO_ERROR);
+}
+
+// Declare the model's inputs and outputs, then finish the model definition.
+void NNAPIModelBuilder::Finish(const std::vector& model_input_operands,
+                               const std::vector& model_output_operands) {
+  const auto model_input_indices = ExtractOperandIndices(model_input_operands);
+  const auto model_output_indices = ExtractOperandIndices(model_output_operands);
+  ICHECK_EQ(ANeuralNetworksModel_identifyInputsAndOutputs(
+                model_, model_input_indices.size(), model_input_indices.data(),
+                model_output_indices.size(), model_output_indices.data()),
+            ANEURALNETWORKS_NO_ERROR);
+  ICHECK_EQ(ANeuralNetworksModel_finish(model_), ANEURALNETWORKS_NO_ERROR);
+}
+
+// Create and finish a compilation with the PREFER_FAST_SINGLE_ANSWER preference.
+// The returned pointer is not freed here.
+ANeuralNetworksCompilation* NNAPIModelBuilder::Compile() {
+  ANeuralNetworksCompilation* compilation;
+  ICHECK_EQ(ANeuralNetworksCompilation_create(model_, &compilation), ANEURALNETWORKS_NO_ERROR);
+  ICHECK_EQ(ANeuralNetworksCompilation_setPreference(compilation,
+                                                     ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER),
+            ANEURALNETWORKS_NO_ERROR);
+  ICHECK_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR);
+  return compilation;
+}
+
+// Map a DLDataType to an NNAPI tensor operand code. Supported: int32, uint1 (as
+// BOOL8), float32 and float16; anything else fails an ICHECK.
+int32_t TensorTypeFromDLDataType(DLDataType ty) {
+  if (ty.code == kDLInt) {
+    if (ty.bits == 32) {
+      return ANEURALNETWORKS_TENSOR_INT32;
+    } else {
+      ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI integer tensor";
+    }
+  } else if (ty.code == kDLUInt) {
+    if (ty.bits == 1) {
+      return ANEURALNETWORKS_TENSOR_BOOL8;
+    } else {
+      ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI unsigned integer tensor";
+    }
+  } else if (ty.code == kDLFloat) {
+    if (ty.bits == 32) {
+      return ANEURALNETWORKS_TENSOR_FLOAT32;
+    } else if (ty.bits == 16) {
+      return ANEURALNETWORKS_TENSOR_FLOAT16;
+    } else {
+      // This branch handles floats; the message previously said "integer tensor".
+      ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI floating-point tensor";
+    }
+  } else {
+    ICHECK(false) << "Unsupported DLDataTypeCode for NNAPI: " << ty.code;
+  }
+}
+
+// Collect the operand index of each operand, preserving order.
+std::vector ExtractOperandIndices(const std::vector& operands) {
+  std::vector indices;
+  indices.reserve(operands.size());
+  std::transform(operands.begin(), operands.end(), std::back_inserter(indices),
+                 [](const NNAPIOperand& operand) { return operand.GetOperandIndex(); });
+  return indices;
+}
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_GRAPH_EXECUTOR_NNAPI
diff --git a/src/runtime/contrib/nnapi/nnapi_builder.h b/src/runtime/contrib/nnapi/nnapi_builder.h
new file mode 100644
index 000000000000..4360f50bf1e9
--- /dev/null
+++ b/src/runtime/contrib/nnapi/nnapi_builder.h
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_BUILDER_H_ +#define TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_BUILDER_H_ +#ifdef TVM_GRAPH_EXECUTOR_NNAPI + +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +class WrappedANeuralNetworksOperandType { + public: + WrappedANeuralNetworksOperandType(int32_t tensor_type, std::vector dimensions, + float scale, int32_t zero_point); + WrappedANeuralNetworksOperandType(const WrappedANeuralNetworksOperandType&); + WrappedANeuralNetworksOperandType& operator=(const WrappedANeuralNetworksOperandType&); + + const ANeuralNetworksOperandType* Get() const; + + private: + std::vector dimensions_; + ANeuralNetworksOperandType ty_; +}; + +class NNAPIOperand { + public: + NNAPIOperand(uint32_t index, const DLTensor* tensor); + NNAPIOperand(uint32_t index, const int64_t* shape, int ndim, DLDataType dtype); + NNAPIOperand(uint32_t index, int32_t tensor_type, std::vector dimensions, float scale, + int32_t zero_point); + static NNAPIOperand Scalar(uint32_t index, int32_t tensor_type, std::vector dimensions, + float scale, int32_t zero_point); + void SetDimensions(std::vector dimensions); + + WrappedANeuralNetworksOperandType GetOperandType() const; + uint32_t GetOperandIndex() const; + const std::vector& GetDimensions() const; + const float GetScale() const; + const int32_t GetZeroPoint() const; + int32_t GetTensorType() const; + bool IsDynamicShape() const; + + private: + uint32_t index_; + bool scalar_; + + // The NNAPI operand type e.g. ANEURALNETWORKS_TENSOR_INT32. 
+ int32_t tensor_type_; + std::vector dimensions_; + float scale_; + int32_t zero_point_; +}; + +class NNAPIModelBuilder { + public: + NNAPIModelBuilder(); + ~NNAPIModelBuilder(); + NNAPIModelBuilder(const NNAPIModelBuilder&) = delete; + NNAPIModelBuilder& operator=(const NNAPIModelBuilder&) = delete; + inline NNAPIModelBuilder(NNAPIModelBuilder&& other) { + model_ = other.model_; + other.model_ = nullptr; + next_operand_index_ = other.next_operand_index_; + other.next_operand_index_ = 0; + } + inline NNAPIModelBuilder& operator=(NNAPIModelBuilder&& other) { + model_ = other.model_; + other.model_ = nullptr; + next_operand_index_ = other.next_operand_index_; + other.next_operand_index_ = 0; + return *this; + } + + NNAPIOperand CreateOperandWithValue(const DLTensor& tensor); + NNAPIOperand CreateOperandWithValue(int32_t tensor_type, std::vector dimensions, + float scale, int32_t zero_point, const void* buffer, + size_t size); + NNAPIOperand CreateScalarOperandWithValue(OperandCode operand_code, const void* buffer, + size_t size); + + NNAPIOperand CreateOperand(const DLTensor& tensor); + NNAPIOperand CreateOperand(const int64_t* shape, int ndim, DLDataType dtype); + NNAPIOperand CreateOperand(int32_t tensor_type, std::vector dimensions, float scale, + int32_t zero_point); + + void AddOperation(ANeuralNetworksOperationType operation, + const std::vector input_indices, + const std::vector output_indices); + + void Finish(const std::vector& model_input_operands, + const std::vector& model_output_operands); + ANeuralNetworksCompilation* Compile(); + + private: + ANeuralNetworksModel* model_; + uint32_t next_operand_index_ = 0; +}; + +/*! + * \brief Convert a DLDataType to an NNAPI OperandCode. 
+ */ +int32_t TensorTypeFromDLDataType(DLDataType ty); + +std::vector ExtractOperandIndices(const std::vector& operands); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_GRAPH_EXECUTOR_NNAPI +#endif // TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_BUILDER_H_ diff --git a/src/runtime/contrib/nnapi/nnapi_ops.cc b/src/runtime/contrib/nnapi/nnapi_ops.cc new file mode 100644 index 000000000000..ad055ec2c76f --- /dev/null +++ b/src/runtime/contrib/nnapi/nnapi_ops.cc @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef TVM_GRAPH_EXECUTOR_NNAPI +#include "nnapi_ops.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nnapi_builder.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +NNAPIOpConverterParams::NNAPIOpConverterParams(const JSONGraphNode& node) : node(node) {} + +NNAPIOpConverter::NNAPIOpConverter(std::string op_name) : op_name_(op_name) {} + +void ElwBinaryOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + // A map from op names to NNAPI OperationCode and whether it requires a FuseCode. 
+ static const std::unordered_map> + op_map = { + {"add", {ANEURALNETWORKS_ADD, true}}, + {"mul", {ANEURALNETWORKS_MUL, true}}, + {"div", {ANEURALNETWORKS_DIV, true}}, + {"sub", {ANEURALNETWORKS_SUB, true}}, + {"pow", {ANEURALNETWORKS_POW, false}}, + {"equal", {ANEURALNETWORKS_EQUAL, false}}, + {"greater", {ANEURALNETWORKS_GREATER, false}}, + {"greater_equal", {ANEURALNETWORKS_GREATER_EQUAL, false}}, + {"less", {ANEURALNETWORKS_LESS, false}}, + {"less_equal", {ANEURALNETWORKS_LESS_EQUAL, false}}, + {"not_equal", {ANEURALNETWORKS_NOT_EQUAL, false}}, + {"maximum", {ANEURALNETWORKS_MAXIMUM, false}}, + {"minimum", {ANEURALNETWORKS_MINIMUM, false}}, + }; + + auto it = op_map.find(op_name_); + ICHECK(it != op_map.end()) << "Unsupported binary operation type " << op_name_; + const ANeuralNetworksOperationType operation_type = std::get<0>(it->second); + const bool requires_fuse_code = std::get<1>(it->second); + + ICHECK_EQ(inputs.size(), 2) << "Expected binary operation to have 2 inputs but got " + << inputs.size(); + + auto input_indices = ExtractOperandIndices(inputs); + const auto output_indices = ExtractOperandIndices(outputs); + + if (requires_fuse_code) { + // Create an extra input at index 2 for the fuse code. 
+ const int32_t fused_none = ANEURALNETWORKS_FUSED_NONE; + const NNAPIOperand fuse_code_operand = builder.CreateScalarOperandWithValue( + ANEURALNETWORKS_INT32, &fused_none, sizeof(fused_none)); + input_indices.push_back(fuse_code_operand.GetOperandIndex()); + } + + builder.AddOperation(operation_type, input_indices, output_indices); +} + +void UnaryOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + static const std::unordered_map op_map = { + // clang-format off + {"floor", ANEURALNETWORKS_FLOOR}, + {"logistic", ANEURALNETWORKS_LOGISTIC}, + {"relu", ANEURALNETWORKS_RELU}, + {"tanh", ANEURALNETWORKS_TANH}, + {"abs", ANEURALNETWORKS_ABS}, + {"exp", ANEURALNETWORKS_EXP}, + {"log", ANEURALNETWORKS_LOG}, + {"neg", ANEURALNETWORKS_NEG}, + {"sqrt", ANEURALNETWORKS_SQRT}, + {"rsqrt", ANEURALNETWORKS_RSQRT}, + // clang-format on + }; + auto it = op_map.find(op_name_); + ICHECK(it != op_map.end()) << "Unsupported unary operation type " << op_name_; + const ANeuralNetworksOperationType operation_type = it->second; + + const auto input_indices = ExtractOperandIndices(inputs); + const auto output_indices = ExtractOperandIndices(outputs); + builder.AddOperation(operation_type, input_indices, output_indices); +} + +void SoftmaxOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + ICHECK_EQ(inputs.size(), 1) << "Unsupported number of inputs for NNAPI softmax operation: " + << inputs.size(); + + auto input_indices = ExtractOperandIndices(inputs); + const auto output_indices = ExtractOperandIndices(outputs); + + // Add the scalar input for beta value at index 1. + const auto& input = inputs[0]; + // TODO(PLLab): Conditionally use float16 beta for float16 input. 
+ ICHECK_EQ(input.GetTensorType(), ANEURALNETWORKS_TENSOR_FLOAT32) + << "NNAPI runtime does not support non-float32 inputs for softmax yet"; + const float beta = 1.0f; + const NNAPIOperand beta_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_FLOAT32, &beta, sizeof beta); + input_indices.push_back(beta_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices, output_indices); +} + +// Insert a reshape operation that reshapes `operand` to `dimensions` and return the reshaped +// operand. +NNAPIOperand ReshapeOperand(NNAPIModelBuilder& builder, const NNAPIOperand& operand, // NOLINT(*) + std::vector dimensions) { + // ANEURALNETWORKS_RESHAPE requires the dimensions to be specified in a int32 tensor. + const std::vector dimensions_int32(dimensions.begin(), dimensions.end()); + const std::vector dim_of_dims{static_cast(dimensions_int32.size())}; + + const NNAPIOperand reshape_shape_operand = + builder.CreateOperandWithValue(ANEURALNETWORKS_TENSOR_INT32, dim_of_dims, 0.0f, 0, + reinterpret_cast(dimensions_int32.data()), + dimensions_int32.size() * sizeof(*dimensions_int32.data())); + const NNAPIOperand reshaped_operand = builder.CreateOperand( + operand.GetTensorType(), dimensions, operand.GetScale(), operand.GetZeroPoint()); + + builder.AddOperation( + ANEURALNETWORKS_RESHAPE, + std::vector{operand.GetOperandIndex(), reshape_shape_operand.GetOperandIndex()}, + std::vector{reshaped_operand.GetOperandIndex()}); + return reshaped_operand; +} + +NNAPIOperand TransposeOperand(NNAPIModelBuilder& builder, const NNAPIOperand& operand, // NOLINT(*) + std::vector dimensions) { + const std::vector dimensions_int32(dimensions.begin(), dimensions.end()); + const std::vector dim_of_axes{static_cast(dimensions_int32.size())}; + std::vector result_dimension; + for (size_t i = 0; i < dimensions.size(); i++) { + result_dimension.push_back(operand.GetDimensions()[dimensions_int32[i]]); + } + + const NNAPIOperand transpose_shape_operand = 
+ builder.CreateOperandWithValue(ANEURALNETWORKS_TENSOR_INT32, dim_of_axes, 0.0f, 0, + reinterpret_cast(dimensions_int32.data()), + dimensions_int32.size() * sizeof(*dimensions_int32.data())); + const NNAPIOperand transposed_operand = builder.CreateOperand( + operand.GetTensorType(), result_dimension, operand.GetScale(), operand.GetZeroPoint()); + + builder.AddOperation( + ANEURALNETWORKS_TRANSPOSE, + std::vector{operand.GetOperandIndex(), transpose_shape_operand.GetOperandIndex()}, + std::vector{transposed_operand.GetOperandIndex()}); + + return transposed_operand; +} + +void MatmulOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + ICHECK_EQ(inputs.size(), 2); + + auto input_indices = ExtractOperandIndices(inputs); + const auto output_indices = ExtractOperandIndices(outputs); + + const size_t input0_ndim = inputs[0].GetDimensions().size(); + const size_t input1_ndim = inputs[1].GetDimensions().size(); + if (input0_ndim != input1_ndim) { + if (input0_ndim > input1_ndim) { + // Check that the extra leading dimensions on input 0 are all ones. + const size_t diff = input0_ndim - input1_ndim; + for (size_t i = 0; i < diff; ++i) { + ICHECK_EQ(inputs[0].GetDimensions()[i], 1); + } + + // Expand input 1's dimensions. + std::vector reshaped_dimensions(diff, 1); + reshaped_dimensions.insert(reshaped_dimensions.end(), inputs[1].GetDimensions().begin(), + inputs[1].GetDimensions().end()); + const auto reshaped_operand = ReshapeOperand(builder, inputs[1], reshaped_dimensions); + input_indices[1] = reshaped_operand.GetOperandIndex(); + } else { + // input0_ndim < input1_ndim + // Check that the extra leading dimensions on input 1 are all ones. + const size_t diff = input1_ndim - input0_ndim; + for (size_t i = 0; i < diff; ++i) { + ICHECK_EQ(inputs[1].GetDimensions()[i], 1); + } + + // Expand input 0's dimensions. 
+ std::vector reshaped_dimensions(diff, 1); + reshaped_dimensions.insert(reshaped_dimensions.end(), inputs[0].GetDimensions().begin(), + inputs[0].GetDimensions().end()); + const auto reshaped_operand = ReshapeOperand(builder, inputs[0], reshaped_dimensions); + input_indices[0] = reshaped_operand.GetOperandIndex(); + } + } + + { + const unsigned char adj_x = 0; + const NNAPIOperand adj_x_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_BOOL, &adj_x, sizeof(adj_x)); + input_indices.push_back(adj_x_operand.GetOperandIndex()); + } + + { + const unsigned char adj_y = 0; + const NNAPIOperand adj_y_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_BOOL, &adj_y, sizeof(adj_y)); + input_indices.push_back(adj_y_operand.GetOperandIndex()); + } + + builder.AddOperation(ANEURALNETWORKS_BATCH_MATMUL, input_indices, output_indices); +} + +void TransposeOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + ICHECK_EQ(inputs.size(), 1); + + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + std::vector axes; + if (node.HasAttr("axes")) { + const auto axes_attr = node.GetAttr>("axes"); + for (auto str_axis : axes_attr) { + axes.push_back(std::stoi(str_axis)); + } + } else { + for (size_t i = 0; i < inputs[0].GetDimensions().size(); ++i) { + axes.push_back(i); + } + std::reverse(axes.begin(), axes.end()); + } + + const std::vector dim_of_axes{static_cast(axes.size())}; + const NNAPIOperand perm_operand = builder.CreateOperandWithValue( + ANEURALNETWORKS_TENSOR_INT32, dim_of_axes, 0.0f, 0, + reinterpret_cast(axes.data()), axes.size() * sizeof(*axes.data())); + input_indices.push_back(perm_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_TRANSPOSE, input_indices, output_indices); +} + +void CastOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& 
inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + // Extract the dtype attribute and check that the output operand type matches the dtype specified. + const auto dtype_attr = node.GetAttr>("astype_dtype"); + ICHECK(dtype_attr.size() == 1); + const auto dtype_str = dtype_attr[0]; + const DLDataType dtype = runtime::String2DLDataType(dtype_str); + ICHECK(outputs.size() == 1); + const auto output_tensor_type = outputs[0].GetTensorType(); + ICHECK(TensorTypeFromDLDataType(dtype) == output_tensor_type) + << "Expect a cast to dtype " << dtype_str << " but got output operand of type " + << output_tensor_type; + + builder.AddOperation(ANEURALNETWORKS_CAST, input_indices, output_indices); +} + +template +NNAPIOperand CreateConv2DBiasOperand(NNAPIModelBuilder& builder, // NOLINT(*) + int64_t output_depth) { + std::vector bias(output_depth, 0.0f); + + const std::vector dim_of_bias{static_cast(bias.size())}; + const NNAPIOperand bias_operand = builder.CreateOperandWithValue( + TensorType, dim_of_bias, 0.0f, 0, reinterpret_cast(bias.data()), + bias.size() * sizeof(*bias.data())); + return bias_operand; +} + +void Conv2dOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + ICHECK(inputs.size() >= 2); + const auto input_tensor_type = inputs[0].GetTensorType(); + const auto filter_tensor_type = inputs[1].GetTensorType(); + ICHECK(input_tensor_type == filter_tensor_type); + ICHECK(input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + ICHECK(filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + + // transpose kernel + std::vector transposed_dimensions{0, 2, 3, 1}; + 
const auto transposed_operand = TransposeOperand(builder, inputs[1], transposed_dimensions); + + input_indices[1] = transposed_operand.GetOperandIndex(); + + // bias operand + if (input_indices.size() == 2) { + const int output_depth = inputs[1].GetDimensions()[0]; + if (input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32) { + const NNAPIOperand bias_operand = + CreateConv2DBiasOperand(builder, output_depth); + input_indices.push_back(bias_operand.GetOperandIndex()); + } else if (input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16) { + const NNAPIOperand bias_operand = + CreateConv2DBiasOperand(builder, output_depth); + input_indices.push_back(bias_operand.GetOperandIndex()); + } + } else { + int64_t bias_dim; + for (int i = 0; i < inputs[2].GetDimensions().size(); i++) { + if (inputs[2].GetDimensions()[i] != 1) { + bias_dim = inputs[2].GetDimensions()[i]; + } + } + std::vector bias_dimension = {bias_dim}; + NNAPIOperand bias_operand = ReshapeOperand(builder, inputs[2], bias_dimension); + input_indices[2] = bias_operand.GetOperandIndex(); + } + // padding operand + std::vector padding; + const auto padding_attr = node.GetAttr>("padding"); + + for (auto str_pad : padding_attr) { + padding.push_back(std::stoi(str_pad)); + } + + ICHECK(padding.size() == 4) << "NNAPI runtime currently only supports 4-way padding for Conv2D"; + const NNAPIOperand padding_left_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[1], sizeof(padding[1])); + input_indices.push_back(padding_left_operand.GetOperandIndex()); + + const NNAPIOperand padding_right_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[3], sizeof(padding[3])); + input_indices.push_back(padding_right_operand.GetOperandIndex()); + + const NNAPIOperand padding_top_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[0], sizeof(padding[0])); + input_indices.push_back(padding_top_operand.GetOperandIndex()); + + const NNAPIOperand 
padding_bottom_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[2], sizeof(padding[2])); + input_indices.push_back(padding_bottom_operand.GetOperandIndex()); + + // stride operand + std::vector stride; + const auto stride_attr = node.GetAttr>("strides"); + for (auto str_stride : stride_attr) { + stride.push_back(std::stoi(str_stride)); + } + + ICHECK(stride.size() == 2); + const NNAPIOperand stride_width_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[0], sizeof(stride[0])); + input_indices.push_back(stride_width_operand.GetOperandIndex()); + + const NNAPIOperand stride_height_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[1], sizeof(stride[1])); + input_indices.push_back(stride_height_operand.GetOperandIndex()); + + // group + int32_t group; + const auto group_attr = node.GetAttr>("group"); + for (auto str_group : group_attr) { + group = std::stoi(str_group); + } + + if (group > 1) { + const NNAPIOperand group_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &group, sizeof(group)); + input_indices.push_back(group_operand.GetOperandIndex()); + } + + // fuse code + const int32_t fused_none = ANEURALNETWORKS_FUSED_NONE; + const NNAPIOperand fuse_code_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &fused_none, sizeof(fused_none)); + input_indices.push_back(fuse_code_operand.GetOperandIndex()); + + // layout + // Use NCHW layout for input 0 and output 0. 
+ const bool layout = true; + const NNAPIOperand layout_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_BOOL, &layout, sizeof(layout)); + input_indices.push_back(layout_operand.GetOperandIndex()); + + if (group > 1) { + builder.AddOperation(ANEURALNETWORKS_GROUPED_CONV_2D, input_indices, output_indices); + } else { + builder.AddOperation(ANEURALNETWORKS_CONV_2D, input_indices, output_indices); + } +} + +void MaxPool2dOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + // padding operand + std::vector padding; + const auto padding_attr = node.GetAttr>("padding"); + + for (auto str_pad : padding_attr) { + padding.push_back(std::stoi(str_pad)); + } + + const NNAPIOperand padding_left_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[1], sizeof(padding[1])); + input_indices.push_back(padding_left_operand.GetOperandIndex()); + + const NNAPIOperand padding_right_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[3], sizeof(padding[3])); + input_indices.push_back(padding_right_operand.GetOperandIndex()); + + const NNAPIOperand padding_top_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[0], sizeof(padding[0])); + input_indices.push_back(padding_top_operand.GetOperandIndex()); + + const NNAPIOperand padding_bottom_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[2], sizeof(padding[2])); + input_indices.push_back(padding_bottom_operand.GetOperandIndex()); + + // stride operand + std::vector stride; + const auto stride_attr = node.GetAttr>("strides"); + for (auto str_stride : stride_attr) { + stride.push_back(std::stoi(str_stride)); + } + + const NNAPIOperand stride_width_operand = + 
builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[0], sizeof(stride[0])); + input_indices.push_back(stride_width_operand.GetOperandIndex()); + + const NNAPIOperand stride_height_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[1], sizeof(stride[1])); + input_indices.push_back(stride_height_operand.GetOperandIndex()); + + // filter operand + std::vector pool_size; + const auto pool_size_attr = node.GetAttr>("pool_size"); + for (auto size : pool_size_attr) { + pool_size.push_back(std::stoi(size)); + } + + const NNAPIOperand pool_size_width_operand = builder.CreateScalarOperandWithValue( + ANEURALNETWORKS_INT32, &pool_size[0], sizeof(pool_size[0])); + input_indices.push_back(pool_size_width_operand.GetOperandIndex()); + + const NNAPIOperand pool_size_height_operand = builder.CreateScalarOperandWithValue( + ANEURALNETWORKS_INT32, &pool_size[1], sizeof(pool_size[1])); + input_indices.push_back(pool_size_height_operand.GetOperandIndex()); + + // fuse code + const int32_t fused_none = ANEURALNETWORKS_FUSED_NONE; + const NNAPIOperand fuse_code_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &fused_none, sizeof(fused_none)); + input_indices.push_back(fuse_code_operand.GetOperandIndex()); + + // layout + const bool layout = true; + const NNAPIOperand layout_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_BOOL, &layout, sizeof(layout)); + input_indices.push_back(layout_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_MAX_POOL_2D, input_indices, output_indices); +} + +void DenseOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + const auto input_tensor_type = inputs[0].GetTensorType(); + const auto filter_tensor_type = inputs[1].GetTensorType(); + ICHECK(input_tensor_type == 
filter_tensor_type); + ICHECK(input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + ICHECK(filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + + if (input_indices.size() == 2) { + const int output_depth = inputs[1].GetDimensions()[0]; + if (input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32) { + const NNAPIOperand bias_operand = + CreateConv2DBiasOperand(builder, output_depth); + input_indices.push_back(bias_operand.GetOperandIndex()); + } else if (input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16) { + const NNAPIOperand bias_operand = + CreateConv2DBiasOperand(builder, output_depth); + input_indices.push_back(bias_operand.GetOperandIndex()); + } + } + + // fuse code + const int32_t fused_none = ANEURALNETWORKS_FUSED_NONE; + const NNAPIOperand fuse_code_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &fused_none, sizeof(fused_none)); + input_indices.push_back(fuse_code_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_FULLY_CONNECTED, input_indices, output_indices); +} + +void MeanOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + // Extract the axis attribute and create an operand for it. + const auto axis_attr = node.GetAttr>("axis"); + std::vector axis; + for (auto dim : axis_attr) { + axis.push_back(std::stoi(dim)); + } + const std::vector dim_of_axis{static_cast(axis.size())}; + + const NNAPIOperand axis_operand = builder.CreateOperandWithValue( + ANEURALNETWORKS_TENSOR_INT32, dim_of_axis, 0.0f, 0, + reinterpret_cast(axis.data()), axis.size() * sizeof(*axis.data())); + input_indices.push_back(axis_operand.GetOperandIndex()); + + // Extract the keepdims attribute and create an operand for it. 
+ const auto keepdims_attr = node.GetAttr>("keepdims"); + ICHECK(keepdims_attr.size() == 1); + const int32_t keepdims = keepdims_attr[0] == "1"; + + const NNAPIOperand keepdims_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &keepdims, sizeof keepdims); + input_indices.push_back(keepdims_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_MEAN, input_indices, output_indices); +} + +const std::unordered_map>& GetOpConverters() { + static const std::unordered_map> map = []() { + std::unordered_map> map; + map.emplace("nnapi.add", std::make_unique("add")); + map.emplace("nnapi.mul", std::make_unique("mul")); + map.emplace("nnapi.div", std::make_unique("div")); + map.emplace("nnapi.sub", std::make_unique("sub")); + map.emplace("nnapi.pow", std::make_unique("pow")); + map.emplace("nnapi.equal", std::make_unique("equal")); + map.emplace("nnapi.greater", std::make_unique("greater")); + map.emplace("nnapi.greater_equal", std::make_unique("greater_equal")); + map.emplace("nnapi.less", std::make_unique("less")); + map.emplace("nnapi.less_equal", std::make_unique("less_equal")); + map.emplace("nnapi.not_equal", std::make_unique("not_equal")); + map.emplace("nnapi.maximum", std::make_unique("maximum")); + map.emplace("nnapi.minimum", std::make_unique("minimum")); + map.emplace("nnapi.floor", std::make_unique("floor")); + map.emplace("nnapi.logistic", std::make_unique("logistic")); + map.emplace("nnapi.relu", std::make_unique("relu")); + map.emplace("nnapi.tanh", std::make_unique("tanh")); + map.emplace("nnapi.abs", std::make_unique("abs")); + map.emplace("nnapi.exp", std::make_unique("exp")); + map.emplace("nnapi.log", std::make_unique("log")); + map.emplace("nnapi.neg", std::make_unique("neg")); + map.emplace("nnapi.sqrt", std::make_unique("sqrt")); + map.emplace("nnapi.rsqrt", std::make_unique("rsqrt")); + map.emplace("nnapi.softmax", std::make_unique()); + map.emplace("nnapi.batch_matmul", std::make_unique()); + 
map.emplace("nnapi.transpose", std::make_unique()); + map.emplace("nnapi.cast", std::make_unique("cast")); + map.emplace("nnapi.mean", std::make_unique("mean")); + map.emplace("nnapi.conv2d", std::make_unique()); + map.emplace("nnapi.fully_connected", std::make_unique()); + map.emplace("nnapi.max_pool_2d", std::make_unique()); + return map; + }(); + return map; +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif // TVM_GRAPH_EXECUTOR_NNAPI diff --git a/src/runtime/contrib/nnapi/nnapi_ops.h b/src/runtime/contrib/nnapi/nnapi_ops.h new file mode 100644 index 000000000000..748a0b1d526c --- /dev/null +++ b/src/runtime/contrib/nnapi/nnapi_ops.h @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_OPS_H_
+#define TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_OPS_H_
+#ifdef TVM_GRAPH_EXECUTOR_NNAPI
+
+// NOTE(review): the include targets below and the template arguments of the
+// std::vector / std::unordered_map uses in this header appear to have been lost
+// (bare `#include`, `std::vector inputs`); restore them from the original file.
+#include
+
+#include
+#include
+#include
+#include
+
+#include "../json/json_node.h"
+#include "nnapi_builder.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+
+/*! \brief Bundle of a JSON graph node with its input and output operand lists. */
+struct NNAPIOpConverterParams {
+  // Held by reference: the node must outlive this struct.
+  const JSONGraphNode& node;
+  std::vector inputs;
+  std::vector outputs;
+  explicit NNAPIOpConverterParams(const JSONGraphNode& node);
+};
+
+/*!
+ * \brief Abstract base of all op converters declared in this header.
+ *
+ * A subclass implements Convert(), which receives the model builder, the JSON
+ * node being converted, the node's input operands and its output operands.
+ */
+class NNAPIOpConverter {
+ public:
+  // Name handed to the constructor; several subclasses below pass "".
+  std::string op_name_;
+
+  explicit NNAPIOpConverter(std::string op_name);
+  virtual ~NNAPIOpConverter() = default;
+
+  // `builder` and `outputs` are non-const references, hence the NOLINT markers.
+  virtual void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,  // NOLINT(*)
+                       const std::vector& inputs,
+                       std::vector& outputs) const = 0;  // NOLINT(*)
+};
+
+/*! \brief Converter constructed with an explicit op name (presumably element-wise binary ops). */
+class ElwBinaryOpConverter : public NNAPIOpConverter {
+ public:
+  inline explicit ElwBinaryOpConverter(std::string op_name) : NNAPIOpConverter(op_name) {}
+  ~ElwBinaryOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*! \brief Converter constructed with an explicit op name (presumably unary ops). */
+class UnaryOpConverter : public NNAPIOpConverter {
+ public:
+  inline explicit UnaryOpConverter(std::string op_name) : NNAPIOpConverter(op_name) {}
+  ~UnaryOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*! \brief Converter whose op name is fixed to "softmax". */
+class SoftmaxOpConverter : public NNAPIOpConverter {
+ public:
+  inline SoftmaxOpConverter() : NNAPIOpConverter("softmax") {}
+  ~SoftmaxOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*! \brief Matmul converter; registers itself with an empty op name. */
+class MatmulOpConverter : public NNAPIOpConverter {
+ public:
+  inline MatmulOpConverter() : NNAPIOpConverter("") {}
+  ~MatmulOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*! \brief Transpose converter; registers itself with an empty op name. */
+class TransposeOpConverter : public NNAPIOpConverter {
+ public:
+  inline TransposeOpConverter() : NNAPIOpConverter("") {}
+  ~TransposeOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*! \brief Cast converter constructed with an explicit op name. */
+class CastOpConverter : public NNAPIOpConverter {
+ public:
+  inline explicit CastOpConverter(std::string op_name) : NNAPIOpConverter(op_name) {}
+  ~CastOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+/*! \brief Conv2d converter; registers itself with an empty op name. */
+class Conv2dOpConverter : public NNAPIOpConverter {
+ public:
+  inline Conv2dOpConverter() : NNAPIOpConverter("") {}
+  ~Conv2dOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*! \brief Dense converter; registers itself with an empty op name. */
+class DenseOpConverter : public NNAPIOpConverter {
+ public:
+  inline DenseOpConverter() : NNAPIOpConverter("") {}
+  ~DenseOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*! \brief Max-pool 2d converter; registers itself with an empty op name. */
+class MaxPool2dOpConverter : public NNAPIOpConverter {
+ public:
+  inline MaxPool2dOpConverter() : NNAPIOpConverter("") {}
+  ~MaxPool2dOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*! \brief Mean converter constructed with an explicit op name. */
+class MeanOpConverter : public NNAPIOpConverter {
+ public:
+  inline explicit MeanOpConverter(std::string op_name) : NNAPIOpConverter(op_name) {}
+  ~MeanOpConverter() = default;
+
+  void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node,
+               const std::vector& inputs,
+               std::vector& outputs) const override;
+};
+
+/*!
+ * \brief Return the converter table by const reference.
+ *
+ * NOTE(review): presumably keyed by op name and holding owning pointers to
+ * NNAPIOpConverter; the key/value types are not visible here, confirm.
+ */
+const std::unordered_map>& GetOpConverters();
+
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_GRAPH_EXECUTOR_NNAPI
+#endif // TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_OPS_H_
diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc
new file mode 100644
index 000000000000..c63098873da1
--- /dev/null
+++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+
+#ifdef TVM_GRAPH_EXECUTOR_NNAPI
+#include
+#include
+
+#include "nnapi_builder.h"
+#include "nnapi_ops.h"
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime::json;
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+
+/*!
+ * \brief JSON runtime module that lowers a JSON graph to an NNAPI model.
+ *
+ * With TVM_GRAPH_EXECUTOR_NNAPI defined, Init() builds and compiles the model
+ * once and Run() executes it. Without it, both entry points abort with
+ * LOG(FATAL).
+ */
+class NNAPIRuntime : public JSONRuntimeBase {
+ public:
+  explicit NNAPIRuntime(const std::string& symbol_name, const std::string& graph_json,
+                        const Array& const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  const char* type_key() const final { return "nnapi"; }
+
+#ifdef TVM_GRAPH_EXECUTOR_NNAPI
+  /*! \brief Result of CompileModel(): the builder, its compilation and the output operands. */
+  struct CompiledModel {
+    CompiledModel(NNAPIModelBuilder builder, ANeuralNetworksCompilation* compilation,
+                  std::vector model_output_operands)
+        : builder(std::move(builder)),
+          compilation(compilation),
+          // The parameter is taken by value, so move it instead of copying again.
+          model_output_operands(std::move(model_output_operands)) {}
+    NNAPIModelBuilder builder;
+    // NOTE(review): nothing in this file frees `compilation`; confirm that
+    // NNAPIModelBuilder owns and releases it.
+    ANeuralNetworksCompilation* compilation;
+    std::vector model_output_operands;
+  };
+
+  // Empty until Init() has compiled the model.
+  std::optional compiled_model_;
+
+  /*!
+   * \brief Bind the constants and compile the NNAPI model.
+   * \param consts The constant tensors; their count must equal const_idx_.size().
+   */
+  void Init(const Array& consts) final {
+    ICHECK_EQ(consts.size(), const_idx_.size())
+        << "The number of input constants must match the number of required constants.";
+    SetupConstants(consts);
+    CompileModel();
+  }
+
+  /*!
+   * \brief Translate the JSON graph into an NNAPI model, compile it and store
+   *  the result in compiled_model_.
+   */
+  void CompileModel() {
+    NNAPIModelBuilder builder;
+
+    // Clear the map, otherwise the input shapes from last inference gets used.
+    node_output_map_.clear();
+
+    // Add inputs as NNAPI model operands. One operand is appended to
+    // model_input_operands per shape entry of every node of type "input";
+    // ExecuteModel() relies on this order when binding input buffers.
+    std::vector model_input_operands;
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      const uint32_t nid = input_nodes_[i];
+      if (nodes_[nid].GetOpType() == "input") {
+        for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
+          const std::vector input_shape = nodes_[nid].GetOpShape()[j];
+          const auto input_dtype = nodes_[nid].GetOpDataType()[j];
+          const NNAPIOperand operand =
+              builder.CreateOperand(input_shape.data(), input_shape.size(), input_dtype);
+          // emplace() keeps the first operand registered for a node id.
+          node_output_map_.emplace(nid, operand);
+          model_input_operands.push_back(operand);
+        }
+      }
+    }
+
+    // Add kernels as NNAPI operations.
+    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+      const auto& node = nodes_[nid];
+      if (node.GetOpType() != "kernel") {
+        continue;
+      }
+      AddOperation(builder, nid, node);
+    }
+
+    // Collect the output operands indices.
+    std::vector model_output_operands;
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      const auto& node = outputs_[i];
+      auto it = node_output_map_.find(node.id_);
+      ICHECK(it != node_output_map_.end()) << "Missing model output.";
+      const auto& operand = it->second;
+      model_output_operands.push_back(operand);
+    }
+
+    // Finish and compile the model.
+    builder.Finish(model_input_operands, model_output_operands);
+    ANeuralNetworksCompilation* compilation = builder.Compile();
+
+    // Store the compilation
+    compiled_model_.emplace(std::move(builder), compilation, model_output_operands);
+  }
+
+  /*!
+   * \brief Run one inference on a compiled model.
+   * \param compilation The compilation to create the execution from.
+   * \param model_output_operands Output operands, in the order of outputs_.
+   */
+  void ExecuteModel(ANeuralNetworksCompilation* compilation,
+                    const std::vector& model_output_operands) {
+    // Execute the model.
+    ANeuralNetworksExecution* execution;
+    ICHECK_EQ(ANeuralNetworksExecution_create(compilation, &execution), ANEURALNETWORKS_NO_ERROR);
+
+    // NNAPI input indices refer to the model input list, to which CompileModel()
+    // adds one entry per shape entry of each "input" node. The position inside
+    // input_nodes_ is therefore not a valid index once a node is skipped or
+    // contributes several entries, so count the bound inputs separately.
+    int32_t input_index = 0;
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      const uint32_t nid = input_nodes_[i];
+      if (nodes_[nid].GetOpType() == "input") {
+        for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
+          // NOTE(review): node_output_map_ holds a single operand per node id, so
+          // every entry j of a node is described by the same operand type; confirm
+          // that input nodes only ever carry one shape entry.
+          auto it = node_output_map_.find(nid);
+          ICHECK(it != node_output_map_.end()) << "Missing model input.";
+          const auto& operand = it->second;
+
+          const uint32_t eid = EntryID(nid, j);
+          const auto entry = data_entry_[eid];
+
+          const auto operand_data_size = GetDataSize(*entry);
+          ICHECK_EQ(ANeuralNetworksExecution_setInput(execution, input_index,
+                                                      operand.GetOperandType().Get(), entry->data,
+                                                      operand_data_size),
+                    ANEURALNETWORKS_NO_ERROR);
+          ++input_index;
+        }
+      }
+    }
+
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      const auto& operand = model_output_operands[i];
+      const auto& node = outputs_[i];
+
+      const auto eid = EntryID(node);
+      const auto entry = data_entry_[eid];
+
+      const auto operand_data_size = GetDataSize(*entry);
+      ICHECK_EQ(ANeuralNetworksExecution_setOutput(execution, i, operand.GetOperandType().Get(),
+                                                   entry->data, operand_data_size),
+                ANEURALNETWORKS_NO_ERROR);
+    }
+
+    // Start the computation and block until it has finished.
+    ANeuralNetworksEvent* compute_event;
+    ICHECK_EQ(ANeuralNetworksExecution_startCompute(execution, &compute_event),
+              ANEURALNETWORKS_NO_ERROR);
+    ICHECK_EQ(ANeuralNetworksEvent_wait(compute_event), ANEURALNETWORKS_NO_ERROR);
+    ANeuralNetworksEvent_free(compute_event);
+
+    ANeuralNetworksExecution_free(execution);
+  }
+
+  /*! \brief Execute the model compiled by Init(); fails if none has been compiled. */
+  void Run() final {
+    ICHECK(compiled_model_.has_value());
+    CompiledModel& compiled_model = compiled_model_.value();
+    ExecuteModel(compiled_model.compilation, compiled_model.model_output_operands);
+  }
+
+  /*!
+   * \brief Convert one kernel node into NNAPI operations.
+   * \param builder The model under construction.
+   * \param nid JSON node id; the node's output operand is recorded under it.
+   * \param node The kernel node to convert.
+   */
+  void AddOperation(NNAPIModelBuilder& builder, uint32_t nid,  // NOLINT(*)
+                    const JSONGraphNode& node) {
+    std::vector inputs;
+    std::vector outputs;
+
+    // Map the op name to its converter.
+    const auto& converter_map = GetOpConverters();
+    auto it = converter_map.find(node.GetOpName());
+    ICHECK(it != converter_map.end()) << node.GetOpName() << ": Unsupported operation name";
+    const NNAPIOpConverter& converter = *it->second;
+
+    // Add input operands to params.
+    for (size_t i = 0; i < node.GetInputs().size(); ++i) {
+      auto in_node = node.GetInputs()[i];
+      auto input_it = node_output_map_.find(in_node.id_);
+      ICHECK(input_it != node_output_map_.end()) << node.GetOpName() << ": Missing input";
+      auto& operand = input_it->second;
+      inputs.push_back(operand);
+    }
+
+    // Create and add output operands to params.
+    const auto output_shapes = node.GetOpShape();
+    const auto output_dtypes = node.GetOpDataType();
+    ICHECK(output_shapes.size() == output_dtypes.size())
+        << "The number of output shapes must match the number of output dtypes";
+    ICHECK(output_shapes.size() == 1)
+        << "NNAPI runtime currently does not support more than one output per operation yet";
+
+    for (size_t i = 0; i < output_shapes.size(); ++i) {
+      auto output_shape = output_shapes[i];
+      const NNAPIOperand output_operand =
+          builder.CreateOperand(output_shape.data(), output_shape.size(), output_dtypes[i]);
+      outputs.push_back(output_operand);
+    }
+
+    converter.Convert(builder, node, inputs, outputs);
+
+    // Record the final output shape.
+    node_output_map_.emplace(nid, outputs[0]);
+  }
+
+ private:
+  // Mapping from JSON node IDs to NNAPI operand numbers.
+  std::unordered_map node_output_map_;
+
+#else   // ifdef TVM_GRAPH_EXECUTOR_NNAPI
+  void Init(const Array& consts) final {
+    LOG(FATAL) << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it.";
+  }
+
+  void Run() final {
+    LOG(FATAL) << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it.";
+  }
+#endif  // ifdef TVM_GRAPH_EXECUTOR_NNAPI
+};
+
+/*! \brief Create an NNAPIRuntime module from a symbol name, graph JSON and constant names. */
+runtime::Module NNAPIRuntimeCreate(const String& symbol_name, const String& graph_json,
+                                   const Array& const_names) {
+  auto n = make_object(symbol_name, graph_json, const_names);
+  return runtime::Module(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.nnapi_runtime_create").set_body_typed(NNAPIRuntimeCreate);
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_nnapi")
+    .set_body_typed(JSONRuntimeBase::LoadFromBinary);
+
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 73800338b143..2d1c33cbf282 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -279,6 +279,14 @@
 #define TVM_INFO_USE_NVSHMEM "NOT-FOUND"
 #endif
 
+#ifndef TVM_INFO_USE_NNAPI_CODEGEN
+#define TVM_INFO_USE_NNAPI_CODEGEN "NOT-FOUND"
+#endif
+
+#ifndef TVM_INFO_USE_NNAPI_RUNTIME
+#define TVM_INFO_USE_NNAPI_RUNTIME "NOT-FOUND"
+#endif
+
 namespace tvm {
 
 /*!
@@ -392,6 +400,8 @@ TVM_DLL Map GetLibInfo() {
       {"USE_MSC", TVM_INFO_USE_MSC},
       {"USE_CCACHE", TVM_INFO_USE_CCACHE},
       {"USE_NVSHMEM", TVM_INFO_USE_NVSHMEM},
+      {"USE_NNAPI_CODEGEN", TVM_INFO_USE_NNAPI_CODEGEN},
+      {"USE_NNAPI_RUNTIME", TVM_INFO_USE_NNAPI_RUNTIME},
       {"BACKTRACE_ON_SEGFAULT", TVM_INFO_BACKTRACE_ON_SEGFAULT},
   };
   return result;
diff --git a/tests/python/nightly/test_nnapi/__init__.py b/tests/python/nightly/test_nnapi/__init__.py
new file mode 100644
index 000000000000..b2606427b1d8
--- /dev/null
+++ b/tests/python/nightly/test_nnapi/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Infrastructure and tests for NNAPI""" diff --git a/tests/python/nightly/test_nnapi/conftest.py b/tests/python/nightly/test_nnapi/conftest.py new file mode 100644 index 000000000000..abed80995a59 --- /dev/null +++ b/tests/python/nightly/test_nnapi/conftest.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+import os
+
+import pytest
+
+from tvm import rpc
+
+
+def remote():
+    """Request an RPC session for the device described by the environment.
+
+    Reads ``TVM_TRACKER_HOST``, ``TVM_TRACKER_PORT`` and ``RPC_DEVICE_KEY``.
+
+    Returns
+    -------
+    (session, tracker)
+        The session requested from the tracker and the tracker connection itself,
+        so that the caller can close both.
+
+    If any of the variables is unset the calling test is skipped. Every caller
+    unpacks two values, so returning ``None`` would fail with ``TypeError``.
+    """
+    required = ("TVM_TRACKER_HOST", "TVM_TRACKER_PORT", "RPC_DEVICE_KEY")
+    missing = [name for name in required if name not in os.environ]
+    if missing:
+        pytest.skip(f"NNAPI RPC environment is not configured, missing: {', '.join(missing)}")
+
+    rpc_tracker_host = os.environ["TVM_TRACKER_HOST"]
+    rpc_tracker_port = int(os.environ["TVM_TRACKER_PORT"])
+    rpc_device_key = os.environ["RPC_DEVICE_KEY"]
+
+    tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port)
+    session = tracker.request(rpc_device_key, priority=0, session_timeout=600)
+    return session, tracker
diff --git a/tests/python/nightly/test_nnapi/infrastructure.py b/tests/python/nightly/test_nnapi/infrastructure.py
new file mode 100644
index 000000000000..aa5580c375ae
--- /dev/null
+++ b/tests/python/nightly/test_nnapi/infrastructure.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+
+import tvm
+import tvm.script.relax as R
+from tvm.contrib import ndk, utils
+from tvm.relax.backend.contrib.nnapi import partition_for_nnapi
+
+
+# pylint: disable=import-outside-toplevel,missing-function-docstring
+def reshape_matmul(mod: tvm.IRModule):
+    """Rewrite 2-D by 2-D ``relax.matmul`` calls in ``main`` as 3-D matmuls.
+
+    Both operands get a leading dimension of 1 and the product is reshaped back
+    to the shape recorded on the original call. Matmuls with any other operand
+    rank are returned unchanged. ``mod`` is updated in place and returned.
+    """
+    from typing import Dict
+
+    from tvm.relax import Expr
+    from tvm.relax.dpl import DFPattern, rewrite_call
+    from tvm.relax.dpl.pattern import is_op, wildcard
+
+    input0 = wildcard()
+    input1 = wildcard()
+    pattern = is_op("relax.matmul")(input0, input1)
+
+    def _rewriter(expr: Expr, matches: Dict[DFPattern, Expr]):
+        i0 = matches[input0]
+        i1 = matches[input1]
+        if len(i0.struct_info.shape) == 2 and len(i1.struct_info.shape) == 2:
+            i0_shape = [1] + [*i0.struct_info.shape.values]
+            i1_shape = [1] + [*i1.struct_info.shape.values]
+            oshape = matches[pattern].struct_info.shape
+            return R.reshape(R.matmul(R.reshape(i0, i0_shape), R.reshape(i1, i1_shape)), oshape)
+        return expr
+
+    mod["main"] = rewrite_call(pattern, _rewriter, mod["main"])
+    return mod
+
+
+def decompose_clip(mod: tvm.IRModule) -> tvm.IRModule:
+    """Rewrite ``relax.clip(x, lo, hi)`` in ``main`` as ``minimum(maximum(x, lo), hi)``.
+
+    The bounds are read from the matched arguments' ``.value.value`` and turned
+    into constants of the input's dtype. ``mod`` is updated in place and returned.
+    """
+    from typing import Dict
+
+    from tvm.relax import Expr
+    from tvm.relax.dpl import DFPattern, rewrite_call
+    from tvm.relax.dpl.pattern import is_op, wildcard
+
+    input_pattern = wildcard()
+    min_pattern = wildcard()
+    max_pattern = wildcard()
+    pattern = is_op("relax.clip")(input_pattern, min_pattern, max_pattern)
+
+    def _as_const(bound: Expr, dtype: str) -> Expr:
+        return R.const(np.array(bound.value.value).astype(dtype), dtype)
+
+    def _rewriter(
+        expr: Expr, matches: Dict[DFPattern, Expr]
+    ) -> Expr:  # pylint: disable=unused-argument
+        data = matches[input_pattern]
+        dtype = data.struct_info.dtype
+        lower = _as_const(matches[min_pattern], dtype)
+        upper = _as_const(matches[max_pattern], dtype)
+        return R.minimum(R.maximum(data, lower), upper)
+
+    mod["main"] = rewrite_call(pattern, _rewriter, mod["main"])
+    return mod
+
+
+def _build(mod, enable_nnapi):
+    """Build ``mod`` for an aarch64 Android target, optionally offloading to NNAPI."""
+    if isinstance(mod, tvm.relay.expr.Call):
+        mod = tvm.IRModule.from_expr(mod)
+
+    if enable_nnapi:
+        mod = tvm.relax.transform.FoldConstant()(mod)
+        mod = reshape_matmul(mod)
+        mod = decompose_clip(mod)
+        mod = partition_for_nnapi(mod)
+
+    mod = tvm.relax.transform.RunCodegen()(mod)
+    return tvm.relax.build(mod, target="llvm -mtriple=aarch64-linux-android")
+
+
+def _run(remote, tracker, ex, inputs):
+    """Export ``ex``, run its ``main`` on the remote CPU and return the output as numpy.
+
+    The RPC connection and the tracker are always closed afterwards, including
+    when exporting, uploading or executing fails.
+    """
+    try:
+        tmp = utils.tempdir()
+        so_name = "test_mod.so"
+        so_path = tmp / so_name
+        ex.export_library(
+            str(so_path), fcompile=ndk.create_shared, options=["-shared", "-fPIC", "-lm"]
+        )
+
+        remote.upload(so_path)
+        dev = remote.cpu(0)
+
+        # Execute the model on the remote.
+        remote_ex = remote.load_module(so_name)
+        vm = tvm.relax.VirtualMachine(remote_ex, device=dev)
+
+        inputs = [x.copyto(dev) for x in inputs]
+
+        vm.set_input("main", *inputs)
+        vm.invoke_stateful("main")
+        # Copy the result to the host before the connection is closed below.
+        return vm.get_outputs("main").numpy()
+    finally:
+        # Manually close the connection.
+        # See https://discuss.tvm.apache.org/t/trouble-with-rpc-session/14008/.
+        #
+        # TODO: Remove if it does not happen on Python 3.11.
+        remote._sess.get_function("CloseRPCConnection")()
+        tracker.close()
+
+
+def build_and_run(
+    remote,
+    tracker,
+    mod,
+    inputs,
+    enable_nnapi=False,
+):
+    """Build ``mod`` and run it on the remote device, returning the numpy output."""
+    ex = _build(mod, enable_nnapi)
+    return _run(remote, tracker, ex, inputs)
diff --git a/tests/python/nightly/test_nnapi/test_network.py b/tests/python/nightly/test_nnapi/test_network.py
new file mode 100644
index 000000000000..742613c25c75
--- /dev/null
+++ b/tests/python/nightly/test_nnapi/test_network.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""NNAPI network tests."""
+
+from typing import List
+
+import numpy as np
+import onnx
+import pytest
+from test_nnapi.conftest import remote
+from test_nnapi.infrastructure import build_and_run  # , build_and_run_vm
+
+import tvm
+import tvm.testing
+from tvm.contrib.download import download_testdata
+from tvm.relax.frontend.onnx import from_onnx
+
+# All models come from one pinned revision of the ONNX model zoo.
+_MODEL_ZOO_ROOT = (
+    "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/"
+)
+
+# Network name -> path of the .onnx file below _MODEL_ZOO_ROOT.
+_MODEL_PATHS = {
+    "vgg11": "vgg11_Opset18_timm/vgg11_Opset18.onnx",
+    "mobilenetv3": (
+        "mobilenetv3_large_100_miil_Opset17_timm/mobilenetv3_large_100_miil_Opset17.onnx"
+    ),
+    "alexnet": "alexnet_Opset17_torch_hub/alexnet_Opset17.onnx",
+    "resnet50": "resnet50_Opset18_timm/resnet50_Opset18.onnx",
+    "resnet34": "resnet34_Opset18_timm/resnet34_Opset18.onnx",
+    "resnet18": "resnet18_Opset18_timm/resnet18_Opset18.onnx",
+    "squeezenet": "squeezenet1_1_Opset18_torch_hub/squeezenet1_1_Opset18.onnx",
+    "vgg16": "vgg16_Opset18_timm/vgg16_Opset18.onnx",
+    "vgg19": "vgg19_Opset18_timm/vgg19_Opset18.onnx",
+}
+
+
+def _build_and_run_network(remote_obj, tracker, mod, input_data):
+    """Run ``mod`` with NNAPI on the remote and with LLVM on the host.
+
+    Returns ``[nnapi_output, host_output]``, both as numpy arrays.
+    """
+
+    def execute_on_host(mod, inputs):
+        with tvm.transform.PassContext(opt_level=3):
+            ex = tvm.relax.build(mod, target="llvm")
+        dev = tvm.cpu(0)
+        vm = tvm.relax.VirtualMachine(ex, device=dev)
+        output = vm["main"](*inputs)
+        return output.numpy()
+
+    nnapi_output = build_and_run(remote_obj, tracker, mod, input_data, enable_nnapi=True)
+    host_output = execute_on_host(mod, input_data)
+    return [nnapi_output, host_output]
+
+
+def get_network(name, dtype, input_shape=(1, 3, 224, 224)):
+    """Download the ONNX model ``name`` and import it as a Relax module.
+
+    Returns the module and a dict mapping the input name to ``(shape, dtype)``.
+
+    Raises
+    ------
+    ValueError
+        If ``name`` is not one of the networks in ``_MODEL_PATHS``.
+    """
+    try:
+        model_url = _MODEL_ZOO_ROOT + _MODEL_PATHS[name]
+    except KeyError:
+        raise ValueError(f"Not supported model {name}") from None
+
+    model_path = download_testdata(model_url, name + ".onnx", module="onnx")
+    onnx_model = onnx.load(model_path)
+
+    shape_dict = {"x": input_shape}
+    mod = from_onnx(onnx_model, shape_dict)
+    return mod, {"data": (input_shape, dtype)}
+
+
+@pytest.mark.parametrize(
+    "name",
+    [
+        "alexnet",
+        "vgg11",
+        "vgg16",
+        "vgg19",
+        "resnet18",
+        "resnet34",
+        "resnet50",
+        "squeezenet",
+        "mobilenetv3",
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "float32",
+    ],
+)
+@tvm.testing.requires_nnapi
+def test_network(name, dtype):
+    """Compare the NNAPI result of a whole network with the host LLVM result."""
+    remote_obj, tracker = remote()
+    print(f"Network evaluating {name} with dtype {dtype}")
+    np.random.seed(0)
+    mod, inputs = get_network(name, dtype)
+
+    input_data = {
+        _name: np.random.uniform(-1.0, 1.0, shape).astype(_dtype)
+        for _name, (shape, _dtype) in inputs.items()
+    }
+
+    inputs_tvm: List[tvm.nd.NDArray] = [tvm.nd.array(v) for v in input_data.values()]
+    outputs = _build_and_run_network(remote_obj, tracker, mod, inputs_tvm)
+    nnapi_out = outputs[0]
+    expected_out = outputs[1]
+    tvm.testing.assert_allclose(nnapi_out, expected_out, rtol=1e-4, atol=1e-5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/nightly/test_nnapi/test_ops.py b/tests/python/nightly/test_nnapi/test_ops.py
new file mode 100644
index 000000000000..589ff6ee89e7
--- /dev/null
+++ b/tests/python/nightly/test_nnapi/test_ops.py
@@ -0,0 +1,362 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""NNAPI integration operator tests."""
+
+from typing import List
+
+import numpy as np
+import pytest
+from test_nnapi.conftest import remote
+from test_nnapi.infrastructure import build_and_run
+
+import tvm
+import tvm.script
+import tvm.script.relax as R
+import tvm.script.tir as T
+import tvm.testing
+
+
+def _build_and_run_network(remote_obj, tracker, mod, input_data):
+    """Run ``mod`` with NNAPI on the remote and with LLVM on the host.
+
+    Returns ``[nnapi_output, host_output]``, both as numpy arrays.
+    """
+
+    def execute_on_host(mod, inputs):
+        with tvm.transform.PassContext(opt_level=3):
+            ex = tvm.relax.build(mod, target="llvm")
+        dev = tvm.cpu(0)
+        vm = tvm.relax.VirtualMachine(ex, device=dev)
+        output = vm["main"](*inputs)
+        return output.numpy()
+
+    nnapi_output = build_and_run(remote_obj, tracker, mod, input_data, enable_nnapi=True)
+    host_output = execute_on_host(mod, input_data)
+    return [nnapi_output, host_output]
+
+
+@pytest.mark.parametrize(
+    "op",
+    [
+        R.exp,
+        R.log,
+        R.negative,
+        R.sqrt,
+        R.rsqrt,
+        R.floor,
+        R.nn.relu,
+        R.nn.softmax,
+        R.sigmoid,
+        R.tanh,
+        R.abs,
+    ],
+)
+def test_unary(op, input_shape=(1, 2, 8, 5)):
+    """Check a unary operator; ``input_shape`` must match the shape in the module."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(i0: R.Tensor((1, 2, 8, 5), "float32")) -> R.Tensor((1, 2, 8, 5), "float32"):
+                with R.dataflow():
+                    t0 = op(i0)
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[np.random.uniform(size=input_shape).astype("float32")],
+    )
+
+
+@pytest.mark.parametrize(
+    "op",
+    [
+        R.power,
+        R.greater,
+        R.add,
+        R.multiply,
+        R.subtract,
+        R.equal,
+        R.less,
+        R.less_equal,
+        R.not_equal,
+        R.maximum,
+        R.minimum,
+        R.greater_equal,
+    ],
+)
+def test_elementwise_binary(op, input_shape=(1, 2, 8, 5)):
+    """Check an element-wise binary operator on two tensors of the same shape."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(
+                i0: R.Tensor((1, 2, 8, 5), "float32"),
+                i1: R.Tensor((1, 2, 8, 5), "float32"),
+            ) -> R.Tensor((1, 2, 8, 5), "float32"):
+                with R.dataflow():
+                    t0 = op(i0, i1)
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[
+            np.random.uniform(size=input_shape).astype("float32"),
+            np.random.uniform(size=input_shape).astype("float32"),
+        ],
+    )
+
+
+def test_divide(input_shape=(1, 2, 8, 5)):
+    """Check division; the divisor is shifted by one so it is never close to zero."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(
+                i0: R.Tensor((1, 2, 8, 5), "float32"),
+                i1: R.Tensor((1, 2, 8, 5), "float32"),
+            ) -> R.Tensor((1, 2, 8, 5), "float32"):
+                with R.dataflow():
+                    t0 = R.divide(i0, i1)
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[
+            np.random.uniform(size=input_shape).astype("float32"),
+            np.random.uniform(size=input_shape).astype("float32") + np.ones(input_shape, "float32"),
+        ],
+    )
+
+
+def test_matmul():
+    """Check a batched 3-D matmul."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(
+                i0: R.Tensor((5, 3, 4), "float32"),
+                i1: R.Tensor((5, 4, 8), "float32"),
+            ) -> R.Tensor((5, 3, 8), "float32"):
+                with R.dataflow():
+                    t0 = R.matmul(i0, i1)
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[
+            np.random.random(size=(5, 3, 4)).astype("float32"),
+            np.random.random(size=(5, 4, 8)).astype("float32"),
+        ],
+    )
+
+
+def test_permute_dims():
+    """Check ``permute_dims`` with axes ``[2, 0, 1]``."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(
+                i0: R.Tensor((5, 4, 8), "float32"),
+            ) -> R.Tensor((8, 5, 4), "float32"):
+                with R.dataflow():
+                    t0 = R.permute_dims(i0, axes=[2, 0, 1])
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[
+            np.random.random(size=(5, 4, 8)).astype("float32"),
+        ],
+    )
+
+
+def test_astype():
+    """Check a float32 to float16 cast."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(
+                i0: R.Tensor((8, 10, 15), "float32"),
+            ) -> R.Tensor((8, 10, 15), "float16"):
+                with R.dataflow():
+                    t0: R.Tensor((8, 10, 15), "float16") = R.astype(i0, dtype="float16")
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[
+            tvm.nd.array(np.random.uniform(size=(8, 10, 15)).astype("float32")),
+        ],
+    )
+
+
+def test_mean():
+    """Check ``mean`` over the last axis with ``keepdims=True``."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(
+                i0: R.Tensor((1, 10, 15), "float32"),
+            ) -> R.Tensor((1, 10, 1), "float32"):
+                with R.dataflow():
+                    # keepdims=True keeps the reduced axis with extent 1.
+                    t0: R.Tensor((1, 10, 1), "float32") = R.mean(i0, axis=[-1], keepdims=True)
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[
+            tvm.nd.array(np.random.uniform(size=(1, 10, 15)).astype("float32")),
+        ],
+    )
+
+
+def test_conv2d():
+    """Check ``conv2d`` followed by an ``add`` with a broadcast per-channel tensor."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(
+                i0: R.Tensor((1, 3, 224, 224), "float32"),
+                i1: R.Tensor((64, 3, 3, 3), "float32"),
+                i2: R.Tensor((1, 64, 1, 1), "float32"),
+            ):
+                with R.dataflow():
+                    t0 = R.nn.conv2d(i0, i1, strides=(1, 1), padding=(1, 1))
+                    t0 = R.add(i2, t0)
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[
+            np.random.random(size=(1, 3, 224, 224)).astype("float32"),
+            np.random.random(size=(64, 3, 3, 3)).astype("float32"),
+            np.random.random(size=(1, 64, 1, 1)).astype("float32"),
+        ],
+    )
+
+
+def test_max_pool2d():
+    """Check ``max_pool2d`` with a 1x1 window, stride 1 and no padding."""
+    remote_obj, tracker = remote()
+
+    def create_model() -> tvm.IRModule:
+        @tvm.script.ir_module
+        class Module:
+            @R.function
+            def main(
+                i0: R.Tensor((1, 1, 28, 28), "float32"),
+            ):
+                with R.dataflow():
+                    t0 = R.nn.max_pool2d(i0, pool_size=(1, 1), strides=(1, 1), padding=(0, 0))
+                    R.output(t0)
+                return t0
+
+        return Module
+
+    mod = create_model()
+    verify(
+        remote_obj,
+        tracker,
+        mod,
+        inputs=[
+            np.random.random(size=(1, 1, 28, 28)).astype("float32"),
+        ],
+    )
+
+
+def verify(remote_obj, tracker, mod, inputs):
+    """Assert that the NNAPI output of ``mod`` matches the host LLVM output."""
+    inputs_tvm: List[tvm.nd.NDArray] = [tvm.nd.array(v) for v in inputs]
+    outputs = _build_and_run_network(remote_obj, tracker, mod, inputs_tvm)
+    nnapi_out = outputs[0]
+    expected_out = outputs[1]
+    tvm.testing.assert_allclose(nnapi_out, expected_out, rtol=1e-4, atol=1e-5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()