Skip to content

Commit b5c1ed4

Browse files
mengshyukotatsuyakiHMZ
committed
[BYOC][NNAPI] This PR introduces NNAPI to TVM
This PR introduces a new BYOC backend for the Android Neural Networks API (NNAPI), enabling execution of neural networks on custom accelerators. This feature adds a new codegen and runtime for NNAPI, supporting operations such as element-wise ops, nn.dense, and nn.conv2d for CNN models with static shapes. Co-authored-by: Ming-Long Huang <[email protected]> Co-authored-by: HMZ <[email protected]>
1 parent 11198f6 commit b5c1ed4

File tree

17 files changed

+2753
-0
lines changed

17 files changed

+2753
-0
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF)
125125
tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "Build with Arm Compute Library graph executor" OFF)
126126
tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF)
127127
tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF)
128+
tvm_option(USE_NNAPI_CODEGEN "Build with NNAPI Codegen support" OFF)
129+
tvm_option(USE_NNAPI_RUNTIME "Build with NNAPI runtime" OFF)
128130
tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions, STATIC, DYNAMIC, or OFF" OFF)
129131
tvm_option(USE_VITIS_AI "Build with VITIS-AI Codegen support" OFF)
130132
tvm_option(SUMMARIZE "Print CMake option summary after configuring" OFF)
@@ -602,6 +604,7 @@ include(cmake/modules/contrib/BNNS.cmake)
602604
include(cmake/modules/contrib/ONNX.cmake)
603605
include(cmake/modules/contrib/ArmComputeLib.cmake)
604606
include(cmake/modules/contrib/TensorRT.cmake)
607+
include(cmake/modules/contrib/NNAPI.cmake)
605608
include(cmake/modules/contrib/VitisAI.cmake)
606609
include(cmake/modules/contrib/Verilator.cmake)
607610
include(cmake/modules/contrib/UMA.cmake)

cmake/modules/LibInfo.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ function(add_lib_info src_file)
144144
TVM_INFO_USE_MSC="${USE_MSC}"
145145
TVM_INFO_USE_CCACHE="${USE_CCACHE}"
146146
TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}"
147+
TVM_INFO_USE_NNAPI_CODEGEN="${USE_NNAPI_CODEGEN}"
148+
TVM_INFO_USE_NNAPI_RUNTIME="${USE_NNAPI_RUNTIME}"
147149
TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
148150
)
149151

cmake/modules/contrib/NNAPI.cmake

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# NNAPI Codegen
if(USE_NNAPI_CODEGEN)
  message(STATUS "Build with NNAPI codegen")

  tvm_file_glob(GLOB COMPILER_NNAPI_SRCS src/relax/backend/contrib/nnapi/*.cc)
  list(APPEND COMPILER_SRCS ${COMPILER_NNAPI_SRCS})
  # When the device-side runtime is disabled, fold the NNAPI runtime sources
  # into the compiler build so that serialization/runtime symbols referenced
  # by the codegen still resolve on the host.  The glob lives inside this
  # branch because the sources are only consumed here; when
  # USE_NNAPI_RUNTIME is ON the runtime section below globs them itself.
  if(NOT USE_NNAPI_RUNTIME)
    tvm_file_glob(GLOB RUNTIME_NNAPI_SRCS src/runtime/contrib/nnapi/*.cc)
    list(APPEND COMPILER_SRCS ${RUNTIME_NNAPI_SRCS})
  endif()
endif()

# NNAPI Runtime
if(USE_NNAPI_RUNTIME)
  message(STATUS "Build with NNAPI runtime")

  tvm_file_glob(GLOB RUNTIME_NNAPI_SRCS src/runtime/contrib/nnapi/*.cc)
  list(APPEND RUNTIME_SRCS ${RUNTIME_NNAPI_SRCS})
  # Link against the Android NDK's NNAPI (libneuralnetworks) and logging
  # (liblog) system libraries.
  list(APPEND TVM_RUNTIME_LINKER_LIBS neuralnetworks log)

  # NOTE(review): add_definitions is directory-scoped; target-scoped
  # definitions would be preferable, but the TVM targets are defined by the
  # top-level CMakeLists, matching the convention of other contrib modules.
  add_definitions(-DTVM_GRAPH_EXECUTOR_NNAPI)
endif()
Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Pattern table for NNAPI backend"""
19+
from typing import (
20+
Mapping,
21+
Optional,
22+
Tuple,
23+
List,
24+
)
25+
from tvm.ir import IRModule
26+
from tvm.relax.transform import FuseOpsByPattern, MergeCompositeFunctions
27+
from tvm.relax.dpl.pattern import (
28+
DFPattern,
29+
wildcard,
30+
is_op,
31+
)
32+
33+
from ..pattern_registry import get_patterns_with_prefix, register_patterns
34+
35+
36+
def elementwise_binary_patterns() -> List[Tuple[str, DFPattern, Mapping[str, DFPattern]]]:
    """
    Returns a list of tuples representing elementwise binary operation patterns mapped
    between NNAPI and Relax frameworks.
    """
    # (NNAPI pattern name, Relax operator name) pairs for two-operand
    # elementwise operations supported by the backend.
    name_pairs = [
        ("nnapi.add", "relax.add"),
        ("nnapi.mul", "relax.multiply"),
        ("nnapi.div", "relax.divide"),
        ("nnapi.sub", "relax.subtract"),
        ("nnapi.pow", "relax.power"),
        ("nnapi.equal", "relax.equal"),
        ("nnapi.greater", "relax.greater"),
        ("nnapi.greater_equal", "relax.greater_equal"),
        ("nnapi.less", "relax.less"),
        ("nnapi.less_equal", "relax.less_equal"),
        ("nnapi.not_equal", "relax.not_equal"),
        ("nnapi.maximum", "relax.maximum"),
        ("nnapi.minimum", "relax.minimum"),
    ]

    def _make_pattern(
        pattern_name: str, op_name: str
    ) -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
        # Both operands are unconstrained; the runtime handles dtype checks.
        lhs = wildcard()
        rhs = wildcard()
        return (pattern_name, is_op(op_name)(lhs, rhs), {})

    return [_make_pattern(nnapi_name, relax_name) for nnapi_name, relax_name in name_pairs]
67+
68+
69+
def unary_patterns() -> List[Tuple[str, DFPattern, Mapping[str, DFPattern]]]:
    """
    Returns a list of tuples representing unary operation patterns mapped
    between NNAPI and Relax frameworks.
    """
    # (NNAPI pattern name, Relax operator name) pairs for single-operand ops.
    name_pairs = [
        ("nnapi.floor", "relax.floor"),
        ("nnapi.relu", "relax.nn.relu"),
        ("nnapi.logistic", "relax.sigmoid"),
        ("nnapi.softmax", "relax.nn.softmax"),
        ("nnapi.tanh", "relax.tanh"),
        ("nnapi.abs", "relax.abs"),
        ("nnapi.exp", "relax.exp"),
        ("nnapi.log", "relax.log"),
        ("nnapi.neg", "relax.negative"),
        ("nnapi.cast", "relax.astype"),
        ("nnapi.sqrt", "relax.sqrt"),
        ("nnapi.rsqrt", "relax.rsqrt"),
    ]

    def _make_pattern(
        pattern_name: str, op_name: str
    ) -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
        return (pattern_name, is_op(op_name)(wildcard()), {})

    return [_make_pattern(nnapi_name, relax_name) for nnapi_name, relax_name in name_pairs]
95+
96+
97+
def matmul_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
    """
    Returns a tuple representing matmul operation patterns mapped
    between NNAPI and Relax frameworks.
    """
    lhs = wildcard()
    rhs = wildcard()
    matmul = is_op("relax.matmul")(lhs, rhs)
    return ("nnapi.batch_matmul", matmul, {})
106+
107+
108+
def permute_dims_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
    """
    Returns a tuple representing permute operation patterns mapped
    between NNAPI and Relax frameworks.
    """
    transposed = is_op("relax.permute_dims")(wildcard())
    return ("nnapi.transpose", transposed, {})
116+
117+
118+
def astype_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
    """
    Returns a tuple representing astype operation patterns mapped
    between NNAPI and Relax frameworks.
    """
    # Restrict both the input and the result to float16/float32: NNAPI's
    # CAST is only offloaded for float-to-float conversions here.
    float_input = wildcard().has_dtype("float16") | wildcard().has_dtype("float32")
    cast_to_fp16 = is_op("relax.astype")(float_input).has_dtype("float16")
    cast_to_fp32 = is_op("relax.astype")(float_input).has_dtype("float32")
    return ("nnapi.cast", cast_to_fp16 | cast_to_fp32, {})
129+
130+
131+
def mean_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
    """
    Returns a tuple representing mean operation patterns mapped
    between NNAPI and Relax frameworks.
    """
    reduced = is_op("relax.mean")(wildcard())
    return ("nnapi.mean", reduced, {})
140+
141+
142+
def conv2d_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
    """
    Returns a tuple representing conv2d operation patterns mapped
    between NNAPI and Relax frameworks.
    """
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    # Match a convolution immediately followed by an add (the bias term):
    # NNAPI's CONV_2D takes the bias as an explicit operand.
    conv = is_op("relax.nn.conv2d")(data, weight)
    conv_with_bias = is_op("relax.add")(conv, bias)
    return ("nnapi.conv2d", conv_with_bias, {})
153+
154+
155+
def max_pool2d_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]:
    """
    Returns a tuple representing max_pool2d operation patterns mapped
    between NNAPI and Relax frameworks.
    """
    pooled = is_op("relax.nn.max_pool2d")(wildcard())
    return ("nnapi.max_pool_2d", pooled, {})
163+
164+
165+
# Register every supported NNAPI pattern with the BYOC pattern registry so
# they are discoverable via the "nnapi" prefix at partition time.
_nnapi_patterns = list(elementwise_binary_patterns())
_nnapi_patterns.extend(unary_patterns())
_nnapi_patterns.extend(
    [
        matmul_pattern(),
        permute_dims_pattern(),
        astype_pattern(),
        mean_pattern(),
        conv2d_pattern(),
        max_pool2d_pattern(),
    ]
)
register_patterns(_nnapi_patterns)
177+
178+
179+
def min_feature_level(pattern_name: str) -> int:
    """
    Returns the minimum feature level required to support a given NNAPI operation pattern.

    Args:
        pattern_name (str): The name of the NNAPI operation pattern
            (e.g., "nnapi.add", "nnapi.conv2d").

    Returns:
        int: The minimum feature level for the specified pattern, or 1 if the pattern is not found.
    """

    # Minimum NNAPI feature level per operation, keyed by pattern name.
    levels = {
        "nnapi.add": 1,
        "nnapi.average_pool_2d": 1,
        "nnapi.concatenation": 1,
        "nnapi.conv2d": 1,
        "nnapi.depthwise_conv_2d": 1,
        "nnapi.depth_to_space": 1,
        "nnapi.dequantize": 1,
        "nnapi.embedding_lookup": 1,
        "nnapi.floor": 1,
        "nnapi.fully_connected": 1,
        "nnapi.hashtable_lookup": 1,
        "nnapi.l2_normalization": 1,
        "nnapi.l2_pool_2d": 1,
        "nnapi.local_response_normalization": 1,
        "nnapi.logistic": 1,
        "nnapi.lsh_projection": 1,
        "nnapi.lstm": 1,
        "nnapi.max_pool_2d": 1,
        "nnapi.mul": 1,
        "nnapi.relu": 1,
        "nnapi.relu1": 1,
        "nnapi.relu6": 1,
        "nnapi.reshape": 1,
        "nnapi.resize_bilinear": 1,
        "nnapi.rnn": 1,
        "nnapi.softmax": 1,
        "nnapi.space_to_depth": 1,
        "nnapi.svdf": 1,
        "nnapi.tanh": 1,
        "nnapi.batch_to_space_nd": 2,
        "nnapi.div": 2,
        "nnapi.mean": 2,
        "nnapi.pad": 2,
        "nnapi.space_to_batch_nd": 2,
        "nnapi.squeeze": 2,
        "nnapi.strided_slice": 2,
        "nnapi.sub": 2,
        "nnapi.transpose": 2,
        "nnapi.abs": 3,
        "nnapi.argmax": 3,
        "nnapi.argmin": 3,
        "nnapi.axis_aligned_bbox_transform": 3,
        "nnapi.bidirectional_sequence_lstm": 3,
        "nnapi.bidirectional_sequence_rnn": 3,
        "nnapi.box_with_nms_limit": 3,
        "nnapi.cast": 3,
        "nnapi.channel_shuffle": 3,
        "nnapi.detection_postprocessing": 3,
        "nnapi.equal": 3,
        "nnapi.exp": 3,
        "nnapi.expand_dims": 3,
        "nnapi.gather": 3,
        "nnapi.generate_proposals": 3,
        "nnapi.greater": 3,
        "nnapi.greater_equal": 3,
        "nnapi.grouped_conv_2d": 3,
        "nnapi.heatmap_max_keypoint": 3,
        "nnapi.instance_normalization": 3,
        "nnapi.less": 3,
        "nnapi.less_equal": 3,
        "nnapi.log": 3,
        "nnapi.logical_and": 3,
        "nnapi.logical_not": 3,
        "nnapi.logical_or": 3,
        "nnapi.log_softmax": 3,
        "nnapi.maximum": 3,
        "nnapi.minimum": 3,
        "nnapi.neg": 3,
        "nnapi.not_equal": 3,
        "nnapi.pad_v2": 3,
        "nnapi.pow": 3,
        "nnapi.prelu": 3,
        "nnapi.quantize": 3,
        "nnapi.quantized_16bit_lstm": 3,
        "nnapi.random_multinomial": 3,
        "nnapi.reduce_all": 3,
        "nnapi.reduce_any": 3,
        "nnapi.reduce_max": 3,
        "nnapi.reduce_min": 3,
        "nnapi.reduce_prod": 3,
        "nnapi.reduce_sum": 3,
        "nnapi.roi_align": 3,
        "nnapi.roi_pooling": 3,
        "nnapi.rsqrt": 3,
        "nnapi.select": 3,
        "nnapi.sin": 3,
        "nnapi.slice": 3,
        "nnapi.split": 3,
        "nnapi.sqrt": 3,
        "nnapi.tile": 3,
        "nnapi.topk_v2": 3,
        "nnapi.transpose_conv_2d": 3,
        "nnapi.unidirectional_sequence_lstm": 3,
        "nnapi.unidirectional_sequence_rnn": 3,
        "nnapi.resize_nearest_neighbor": 3,
        "nnapi.quantized_lstm": 4,
        "nnapi.if": 4,
        "nnapi.while": 4,
        "nnapi.elu": 4,
        "nnapi.hard_swish": 4,
        "nnapi.fill": 4,
        "nnapi.rank": 4,
        "nnapi.batch_matmul": 6,
        "nnapi.pack": 6,
        "nnapi.mirror_pad": 7,
        "nnapi.reverse": 7,
    }
    # Default to 1 for unknown patterns, as documented, instead of raising
    # KeyError (the original `levels[pattern_name]` contradicted the
    # docstring's "or 1 if the pattern is not found" contract).
    return levels.get(pattern_name, 1)
300+
301+
302+
def partition_for_nnapi(mod: IRModule, feature_level: Optional[int] = None) -> IRModule:
    """Partition the graph greedily offloading supported operators to NNAPI.

    Parameters
    ----------
    mod : tvm.ir.IRModule
        The module to run passes on.
    feature_level : Optional[int]
        The maximum NNAPI feature level.

    Returns
    -------
    mod : tvm.ir.IRModule
        Annotated and partitioned module.
    """
    candidate_patterns = get_patterns_with_prefix("nnapi")
    if feature_level is not None:
        # Keep only patterns whose minimum feature level is supported.
        candidate_patterns = [
            entry for entry in candidate_patterns if min_feature_level(entry.name) <= feature_level
        ]
    fuse = FuseOpsByPattern(candidate_patterns, bind_constants=False, annotate_codegen=False)
    merge = MergeCompositeFunctions()
    return merge(fuse(mod))

0 commit comments

Comments
 (0)