Commit 9ada371

[microNPU] Add the infrastructure for lookup table and TANH (#9547)
Some activation functions, such as TANH and SIGMOID, are implemented by calculating the output values from the QNN parameters and recording them in a lookup table (LUT). This patch adds the LUT functionality along with the TANH activation function and tests.
1 parent 3047709 commit 9ada371

21 files changed: +832 −42 lines
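
The table construction amounts to, for each of the 256 possible int8 input values: dequantize, apply the real-valued function, and requantize the result. A standalone sketch of that calculation (plain Python, mirroring the find_tanh_values helper added below; the rounding step stands in for the patch's util.round_away_zero):

import math

def tanh_lut_int8(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
    """Precompute tanh for every possible int8 input value."""
    qmin, qmax = -128, 127
    lut = []
    for x in range(qmin, qmax + 1):
        real = math.tanh(ifm_scale * (x - ifm_zp))            # dequantize, then tanh
        q = ofm_zp + real / ofm_scale                         # requantize
        q = int(math.copysign(math.floor(abs(q) + 0.5), q))   # round away from zero
        lut.append(min(qmax, max(qmin, q)))                   # clamp to the int8 range
    return lut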

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 110 additions & 0 deletions
@@ -22,6 +22,115 @@
 from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
 from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
 from tvm.relay.backend.contrib.ethosu import util
+from tvm.relay.expr_functor import ExprMutator
+from tvm.ir.transform import Pass
+
+# pylint: disable=unused-import
+from tvm.relay.backend.contrib.ethosu.op import op_attrs
+from tvm.relay.backend.contrib.ethosu import op
+
+
+class OptimizeLUTs(ExprMutator):
+    """A pass to merge an identity operator with a LUT based activation function with
+    a preceding operator provided that operator can do a table lookup for the activation
+    in the hardware"""
+
+    def __init__(self):
+        super().__init__()
+        self.lut_ops = {
+            "contrib.ethosu.conv2d": op.ethosu_conv2d,
+            "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
+            "contrib.ethosu.pooling": op.ethosu_pooling,
+        }
+
+    def create_op_with_lut(self, call):
+        """Extract the parameters and attributes from the NPU operator and create
+        a new operator with LUT.
+
+        Parameters
+        ----------
+        call : tvm.relay.expr.Call
+            The current call node being visited.
+
+        Returns
+        -------
+        tvm.relay.expr.Call
+            The new operator with LUT.
+        """
+        identity = call
+        ethosu_op = call.args[0]
+        lut = identity.args[1]
+        activation = identity.attrs.activation
+
+        new_attrs = dict(ethosu_op.attrs)
+        new_attrs["activation"] = activation
+
+        # Assume that LUT is always the last argument
+        new_args = ethosu_op.args[:-1] + [lut]
+        assert ethosu_op.op.name in self.lut_ops.keys()
+
+        return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs)
+
+    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+        """Recursively visit call nodes in the input graph and if an ethosu.identity
+        operator with LUT is found and the preceding operator has a LUT attribute, create
+        a new NPU operator.
+
+        Parameters
+        ----------
+        call : tvm.relay.expr.Call
+            The current call node being visited.
+
+        Returns
+        -------
+        tvm.relay.expr.Call
+            The input call node in the case the current call node does
+            not refer to an Op. Else, a new call node with a new operator.
+        """
+        new_call = call
+        lut_activations = ["TANH", "LUT"]
+
+        if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call):
+            producer_op = call.args[0]
+            # Check if the producer can do a LUT operation
+            if (
+                producer_op.op.name in self.lut_ops.keys()
+                and call.op.name == "contrib.ethosu.identity"
+                and call.attrs.activation in lut_activations
+            ):
+                # Check the producer doesn't already have a LUT
+                has_lut = producer_op.attrs.activation in lut_activations
+                if not has_lut:
+                    new_call = self.create_op_with_lut(call)
+
+        new_call = super().visit_call(new_call)
+
+        return new_call
+
+
+@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer")
+class LUTsOptimizer(Pass):
+    """Register LUTsOptimizer as a relay pass."""
+
+    def transform_function(
+        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
+    ) -> tvm.IRModule:
+        """Visit relay nodes in the given module.
+
+        Parameters
+        ----------
+        func : tvm.relay.function.Function
+            The function to apply the optimization pass for multiple LUTs to.
+        mod : tvm.IRModule
+            The module to apply the optimization pass for multiple LUTs to.
+
+        Returns
+        -------
+        mod : tvm.IRModule
+            New module with optimized LUTs.
+        """
+        assert len(mod.functions.items()) == 1, "Module can only contain one function."
+        return OptimizeLUTs().visit(func)


 @tvm._ffi.register_func("relay.ext.ethos-u")
@@ -74,6 +183,7 @@ def _compile(ext_func):
     mod = tvm.IRModule()
     mod["main"] = ext_func
     mod = LegalizeEthosU()(mod)
+    mod = LUTsOptimizer()(mod)
     mod = relay.transform.InferType()(mod)
     # We are currently using copy_constants scheduler In the long run,
     # this should be a single intelligent and a composite scheduler
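
OptimizeLUTs follows the standard ExprMutator recipe: override visit_call, optionally swap the matched call for a rewritten one, then let the base class recurse into the arguments. The same traversal shape on a toy rewrite, deliberately unrelated to the NPU ops (illustrative sketch; assumes TVM is installed):

import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator


class AddToMul(ExprMutator):
    """Toy rewrite: turn every add into a multiply."""

    def visit_call(self, call):
        new_call = call
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "add":
            new_call = relay.multiply(call.args[0], call.args[1])
        # Recurse into the (possibly rewritten) call's arguments
        return super().visit_call(new_call)


x = relay.var("x", shape=(4,), dtype="float32")
func = relay.Function([x], relay.add(x, x))
print(AddToMul().visit(func))  # the body is now multiply(%x, %x)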

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 72 additions & 0 deletions
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
 """A set of passes to legalize some of operations for the NPU"""
 from typing import List, Type
+import math

 import numpy as np  # type: ignore

@@ -31,6 +32,7 @@
 from tvm.relay.backend.contrib.ethosu import op as ethosu_ops  # type: ignore
 from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout  # type: ignore
 from tvm.relay.backend.contrib.ethosu import vela_api
+from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.op.contrib import ethosu as ethosu_patterns  # type: ignore


@@ -123,6 +125,75 @@ def __call__(self, *args, **kwargs):
         pass


+def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
+    """Method to calculate the values of the tanh lookup table"""
+    lut_values = list()
+    # Only int8 is currently supported
+    dtype = np.int8
+    qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
+    for x in range(qmin, qmax + 1):
+        x_real = ifm_scale * (x - ifm_zp)
+        out_real = math.tanh(x_real)
+        lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
+        lut_result = min(qmax, max(qmin, lut_result))
+        lut_values.append(lut_result)
+
+    return lut_values
+
+
+class TanhRewriter(DFPatternCallback):
+    """This pass adds tanh as a LUT to the identity operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
+        )(wildcard())
+
+    def callback(self, pre, post, node_map):
+        id_input = post.args[0]
+
+        quantize_args = post.op.body.args
+        output_scale = float(quantize_args[1].data.asnumpy())
+        output_zp = int(quantize_args[2].data.asnumpy())
+
+        dequantize_args = quantize_args[0].args[0].args
+        input_scale = float(dequantize_args[1].data.asnumpy())
+        input_zp = int(dequantize_args[2].data.asnumpy())
+
+        lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp)
+        lut = relay.const(lut_values, dtype="uint8")
+
+        # We baked the requantization into the LUT, so we don't requantize the identity operator
+        identity = ethosu_ops.ethosu_identity(
+            ifm=id_input,
+            lut=lut,
+            ifm_scale=input_scale,
+            ifm_zero_point=input_zp,
+            ofm_scale=input_scale,
+            ofm_zero_point=input_zp,
+            activation="TANH",
+        )
+
+        return identity
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeTanh:
+    """This is the pass that wraps TanhRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(TanhRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 class Conv2DRewriter(DFPatternCallback):
     """Convert conv2d related composite functions into ethosu_conv2d operators"""

@@ -915,6 +986,7 @@ def transform_module(
         mod = LegalizeMax()(mod)
         mod = LegalizeShl()(mod)
         mod = LegalizeAbs()(mod)
+        mod = LegalizeTanh()(mod)
         mod = LegalizeReshape()(mod)
         mod = LegalizeStridedSlice()(mod)
         mod = LegalizeNoOps()(mod)
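
The helper can also be exercised directly as a sanity check; the quantization parameters below are hypothetical, chosen only for illustration:

from tvm.relay.backend.contrib.ethosu.legalize import find_tanh_values

# Hypothetical QNN parameters, for illustration only
lut = find_tanh_values(ifm_scale=0.02, ifm_zp=0, ofm_scale=1 / 128, ofm_zp=0)
assert len(lut) == 256                      # one entry per possible int8 input
assert all(-128 <= v <= 127 for v in lut)   # every entry is clamped to the int8 range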
python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The attributes node used for Arm(R) Ethos(TM)-U NPU Relay operators."""
+from tvm.ir import Attrs
+import tvm._ffi
+
+
+@tvm._ffi.register_object("relay.attrs.EthosuConv2DAttrs")
+class EthosuConv2DAttrs(Attrs):
+    """Attributes for contrib.ethosu.conv2d."""
+
+
+@tvm._ffi.register_object("relay.attrs.EthosuIdentityAttrs")
+class EthosuIdentityAttrs(Attrs):
+    """Attributes for contrib.ethosu.identity."""
+
+
+@tvm._ffi.register_object("relay.attrs.EthosuDepthwiseConv2DAttrs")
+class EthosuDepthwiseConv2DAttrs(Attrs):
+    """Attributes for contrib.ethosu.depthwise_conv2d."""
+
+
+@tvm._ffi.register_object("relay.attrs.EthosuPoolingAttrs")
+class EthosuPooling2DAttrs(Attrs):
+    """Attributes for contrib.ethosu.pooling."""

python/tvm/relay/backend/contrib/ethosu/te/convolution.py

Lines changed: 8 additions & 1 deletion
@@ -140,6 +140,13 @@ def conv2d_compute(
         "dilation_w": dilation_w,
     }

+    # This is a trick to insert the LUT tensor into the TE graph if LUT is present
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+
+    # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
+    if activation in ("TANH", "LUT"):
+        conv2d_attrs["lut"] = lut
+
     conv = te.compute(
         (1, ofm_height, ofm_width, ofm_channels),
         lambda nn, hh, ww, cc: te.sum(
@@ -148,7 +155,7 @@
             ).astype(ifm.dtype)
             * weight[cc, rh, rw, rc].astype(ifm.dtype)
             # This is a trick to load 10 elements of the scale_bias at once, not accurate maths
-            + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
+            + (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype),
             axis=[rh, rw, rc],
         ),
         name="ethosu_conv2d",
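
The lut[0] + lut[255] term above is not meaningful arithmetic: referencing the LUT inside the compute body creates a data dependency, which is what makes TE record the LUT as an input tensor of the operation (the same kind of trick the existing scale_bias term plays). The identical pattern appears in the depthwise, identity and pooling compute functions below. The mechanism in isolation, as a small sketch (assumes TVM; names are illustrative):

import tvm
from tvm import te

ifm = te.placeholder((16,), dtype="int8", name="ifm")
lut = te.placeholder((256,), dtype="uint8", name="lut")

# Referencing lut in the body makes it an input of the compute op,
# even though the added value is a "don't care" here
lut_expr = (lut[0] + lut[255]).astype("int8")
out = te.compute((16,), lambda i: ifm[i] + lut_expr, name="out")

print([t.name for t in out.op.input_tensors])  # both 'ifm' and 'lut' appear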

python/tvm/relay/backend/contrib/ethosu/te/depthwise.py

Lines changed: 8 additions & 1 deletion
@@ -136,6 +136,13 @@ def depthwise_conv2d_compute(
         "dilation_w": dilation_w,
     }

+    # This is a trick to insert the LUT tensor into the TE graph if LUT is present
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+
+    # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
+    if activation in ("TANH", "LUT"):
+        depthwise_conv2d_attrs["lut"] = lut
+
     depthwise = te.compute(
         (1, ofm_height, ofm_width, channels),
         lambda nn, hh, ww, cc: te.sum(
@@ -144,7 +151,7 @@ def depthwise_conv2d_compute(
             ).astype(ifm.dtype)
             * weight[cc, rh, rw, 0].astype(ifm.dtype)
             # This is a trick to load 10 elements of the scale_bias at once, not accurate maths
-            + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
+            + (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype),
             axis=[rh, rw],
         ),
         name="ethosu_depthwise_conv2d",

python/tvm/relay/backend/contrib/ethosu/te/identity.py

Lines changed: 10 additions & 3 deletions
@@ -58,14 +58,21 @@ def identity_compute(
         The Output Feature Map tensor.

     """
-
     dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
+    id_attrs = {"op": "ethosu_identity", "activation": activation}
+
+    # This is a trick to insert the LUT tensor into the TE graph if LUT is present
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+
+    # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
+    if activation in ("TANH", "LUT"):
+        id_attrs["lut"] = lut

     identity = te.compute(
         ifm.shape,
-        lambda *i: dmaed_ifm(*i).astype(ifm.dtype),
+        lambda *i: (dmaed_ifm(*i) + lut_expr).astype(ifm.dtype),
         name="ethosu_identity",
-        attrs={"op": "ethosu_identity", "activation": activation},
+        attrs=id_attrs,
     )

     dmaed_ofm = write_compute(identity, ofm_zero_point, ofm_scale)

python/tvm/relay/backend/contrib/ethosu/te/pooling.py

Lines changed: 10 additions & 1 deletion
@@ -123,10 +123,19 @@ def pooling_compute(
         "upscale": upscale,
     }

+    # This is a trick to insert the LUT tensor into the TE graph if LUT is present
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+
+    # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
+    if activation in ("TANH", "LUT"):
+        pooling_attrs["lut"] = lut
+
     pooling = te.compute(
         (1, ofm_height, ofm_width, ofm_channels),
         lambda nn, hh, ww, cc: te.max(
-            dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc).astype(ifm.dtype),
+            (dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype(
+                ifm.dtype
+            ),
             axis=[rh, rw],
         ),
         name="ethosu_pooling",

python/tvm/relay/backend/contrib/ethosu/tir/convolution.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def get_conv2d_params(stmt, producers, consumers):
     rh = inner
     rw = rh.body
     rc = rw.body
-    # loads = [output, input, weights, scale_bias, scale_bias]
+    # loads = [output, input, weights, scale_bias, scale_bias, LUT, LUT]
     loads = get_loads(rc.body)
     # stores = [output]
     stores = get_stores(rc.body)

python/tvm/relay/backend/contrib/ethosu/tir/identity.py

Lines changed: 5 additions & 2 deletions
@@ -19,7 +19,7 @@
 from typing import Dict, Tuple
 import tvm
 from .spec import SerialKernel, SerialActivation, SerialPooling, SerialPadding, SerialFeatureMap
-from .utils import get_op_attrs, get_base_address, get_strides
+from .utils import get_op_attrs, get_base_address, get_strides, get_loads


 def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]:
@@ -123,7 +123,10 @@ def get_identity_params(
     while hasattr(stmt, "body"):
         stmt = stmt.body

-    input_pointer = stmt.value.buffer_var
+    # loads = [input, LUT, LUT]
+    loads = get_loads(stmt)
+
+    input_pointer = loads[0].buffer_var
     output_pointer = stmt.buffer_var

     read = producers[input_pointer]
