
Commit cff4568

[microNPU] Merge LUT activation with binary elementwise operation (#13935)
Add the binary elementwise operator to the OptimizeLUTs pass so that a LUT-based activation (e.g. SIGMOID or TANH) following the elementwise operation is merged into it.
1 parent: 5cf3405

6 files changed: +80 / -25 lines
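At a high level, the change teaches the existing LUTsOptimizer pass to fold an `ethosu_identity` operator that carries a LUT into a preceding `ethosu_binary_elementwise`, the same way it already does for conv2d, depthwise_conv2d and pooling. Below is a minimal sketch of driving the pass over an already-partitioned module; the import path matches the codegen.py file edited below, while the wrapper function and module handling are illustrative assumptions, not code from this commit:

```python
# Sketch only: LUTsOptimizer comes from the codegen module edited below;
# `mod` is assumed to be an already partitioned "ethos-u" Relay module.
from tvm import relay
from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer


def fold_luts(mod):
    # Fold identity-with-LUT activations (e.g. SIGMOID/TANH tables) into the
    # preceding NPU operator, which now includes ethosu_binary_elementwise.
    mod = LUTsOptimizer()(mod)
    return relay.transform.InferType()(mod)
```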

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@ def __init__(self):
             "contrib.ethosu.conv2d": op.ethosu_conv2d,
             "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
             "contrib.ethosu.pooling": op.ethosu_pooling,
+            "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
         }

     def create_op_with_lut(self, call):

python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py

Lines changed: 11 additions & 2 deletions

@@ -178,6 +178,14 @@ def binary_elementwise_compute(
     }
     broadcast = [value == 1 for value in dmaed_ifm2.shape]

+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+    # This is a trick to insert the LUT tensor into the TE graph if LUT is present
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
+
+    # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
+    if has_lut:
+        binary_elementwise_attrs["lut"] = lut
+
     if reversed_operands:
         binary_elementwise = te.compute(
             (1, ofm_height, ofm_width, ifm_channels),
@@ -188,7 +196,7 @@ def binary_elementwise_compute(
                     0 if broadcast[2] else ww,
                     0 if broadcast[3] else cc,
                 ).astype(ifm.dtype),
-                dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype),
+                dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype) + lut_expr,
             ).astype(ofm_dtype),
             name="ethosu_binary_elementwise",
             attrs=binary_elementwise_attrs,
@@ -203,7 +211,8 @@ def binary_elementwise_compute(
                     0 if broadcast[1] else hh,
                     0 if broadcast[2] else ww,
                     0 if broadcast[3] else cc,
-                ).astype(ifm.dtype),
+                ).astype(ifm.dtype)
+                + lut_expr,
             ).astype(ofm_dtype),
             name="ethosu_binary_elementwise",
             attrs=binary_elementwise_attrs,
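The `lut_expr` term above is what the inline comment calls a trick: referencing two LUT entries and adding the (otherwise meaningless) result to the IFM load makes the LUT tensor a data dependency of the output, so it is carried through TE into TIR, where the extraction code below can find it. Here is a small standalone illustration of the same dependency trick with generic placeholders; the names and shapes are illustrative assumptions, not taken from the backend:

```python
# Illustrative sketch of the TE dependency trick with generic placeholders;
# names and shapes are made up, this is not code from the backend.
from tvm import te

ifm = te.placeholder((1, 4, 4, 8), dtype="int8", name="ifm")
lut = te.placeholder((256,), dtype="int8", name="lut")

# Referencing two LUT entries inside the compute body makes `lut` an input of
# `out`, so the table survives into the lowered TE/TIR graph even though the
# added value itself is irrelevant to how the NPU consumes the table.
lut_expr = (lut[0] + lut[255]).astype("int8")
out = te.compute(
    ifm.shape,
    lambda n, h, w, c: ifm[n, h, w, c] + lut_expr,
    name="ifm_with_lut_dependency",
)

print(out.op.input_tensors)  # lists both `ifm` and `lut`
```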

python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py

Lines changed: 5 additions & 22 deletions

@@ -18,30 +18,12 @@
 """Extract information from the binary_elementwise operators in TIR."""
 from typing import Tuple
 import tvm
-from .utils import get_outer_loops, get_op_attrs
+from .utils import get_outer_loops, get_op_attrs, get_loads
 from .dma import get_ifm_params, get_ofm_params
 from .spec import SerialActivation, SerialBinaryElementwise, SerialRescaleConfig
 from .producers_consumers import ProducersConsumers


-def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:
-    """When the datatype of the ifm, ifm2 and ofm do not match,
-    casts are inserted in TE to handle the difference in these types.
-    Since TIR is not directly run on the NPU we can simply ignore
-    these, and allow the NPU to handle the difference in datatypes
-    itself.
-
-    Parameters
-    ----------
-    tir_load : tvm.tir.expr.Load
-
-    Returns
-    -------
-    tvm.tir.Var
-    """
-    return tir_load.value if isinstance(tir_load, tvm.tir.Cast) else tir_load
-
-
 def get_binary_elementwise_params(
     stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
 ) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]:
@@ -72,9 +54,10 @@ def get_binary_elementwise_params(
     reversed_operands = attrs["reversed_operands"]

     _, _, _, _, _, inner = get_outer_loops(body, "NHWC")
-    op = ignore_cast(inner.value)
-    input_pointer = ignore_cast(op.a).buffer.data
-    input_pointer1 = ignore_cast(op.b).buffer.data
+    # loads = [input, input, LUT, LUT]
+    loads = get_loads(inner)
+    input_pointer = loads[0].buffer.data
+    input_pointer1 = loads[1].buffer.data

     if reversed_operands:
         input_pointer, input_pointer1 = input_pointer1, input_pointer
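Because the LUT now appears in the compute expression, the innermost statement contains buffer loads for the two inputs followed by the LUT references, which is what the `# loads = [input, input, LUT, LUT]` comment records; `get_loads` collects them, the first two give the input buffers, and the old cast-stripping helper becomes unnecessary. For context, a rough sketch of what a `get_loads`-style helper can look like; the actual implementation lives in `.utils` and may differ in detail:

```python
# Assumed sketch of a get_loads-style helper; the real one lives in .utils
# and may differ in detail.
from typing import List

import tvm


def collect_buffer_loads(stmt) -> List[tvm.tir.BufferLoad]:
    """Collect every BufferLoad nested inside a TIR statement or expression."""
    loads: List[tvm.tir.BufferLoad] = []

    def _visit(node):
        if isinstance(node, tvm.tir.BufferLoad):
            loads.append(node)

    tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
    return loads
```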

tests/python/contrib/test_ethosu/infra.py

Lines changed: 2 additions & 1 deletion

@@ -694,11 +694,12 @@ def make_ethosu_binary_elementwise(
     use_rescale: bool = False,
     rescale_scale: int = 0,
     rescale_shift: int = 0,
+    lut=relay.const([], dtype="int8"),
 ):
     ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
         ifm=ifm,
         ifm2=ifm2,
-        lut=relay.const([], dtype="int8"),
+        lut=lut,
         operator_type=operator_type,
         ifm_scale=1,
         ifm_zero_point=0,
tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 19 additions & 0 deletions

@@ -1313,5 +1313,24 @@ def fully_connected(x):
     )


+@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
+def test_tflite_subtract_sigmoid(accel_type):
+    np.random.seed(0)
+    ifm_shape = [1, 6, 8, 4]
+
+    @tf.function
+    def subtract_sigmoid_function(lhs, rhs):
+        op = tf.math.subtract(lhs, rhs)
+        op = tf.nn.sigmoid(op)
+        return op
+
+    infra.compare_tvm_with_tflite(
+        subtract_sigmoid_function,
+        [ifm_shape, ifm_shape],
+        accel_type,
+        enable_cascader=is_u55_accel_type(accel_type),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

tests/python/contrib/test_ethosu/test_lut_optimizer.py

Lines changed: 42 additions & 0 deletions

@@ -72,6 +72,48 @@ def after():
     assert tvm.ir.structural_equal(mod, after())


+def test_merge_lut_into_binary_elementwise():
+    """If a binary elementwise operator is followed by an identity operator
+    with LUT, we can merge the two operators."""
+
+    shape = (1, 8, 8, 4)
+    dtype = "int8"
+    ifm = relay.var("x", shape=shape, dtype=dtype)
+    ifm2 = relay.var("x", shape=shape, dtype=dtype)
+    lut1 = relay.const([i for i in range(256)], dtype=dtype)
+    lut2 = relay.const([i for i in reversed(range(256))], dtype=dtype)
+
+    def before():
+        sub = infra.make_ethosu_binary_elementwise(ifm, ifm2, shape[-1], shape[-1], "SUB", dtype)
+        id1 = infra.make_ethosu_identity(sub, lut=lut1, activation="TANH")
+        add = infra.make_ethosu_binary_elementwise(id1, ifm2, shape[-1], shape[-1], "ADD", dtype)
+        id2 = infra.make_ethosu_identity(add, lut=lut2, activation="SIGMOID")
+
+        func = relay.Function(relay.analysis.free_vars(id2), id2)
+        func = func.with_attr("Compiler", "ethos-u")
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    def after():
+        sub = infra.make_ethosu_binary_elementwise(
+            ifm, ifm2, shape[-1], shape[-1], "SUB", dtype, lut=lut1, activation="TANH"
+        )
+        add = infra.make_ethosu_binary_elementwise(
+            sub, ifm2, shape[-1], shape[-1], "ADD", dtype, lut=lut2, activation="SIGMOID"
+        )
+
+        func = relay.Function(relay.analysis.free_vars(add), add)
+        func = func.with_attr("Compiler", "ethos-u")
+        mod = tvm.IRModule.from_expr(func)
+        mod = relay.transform.InferType()(mod)
+        return mod
+
+    mod = LUTsOptimizer()(before())
+    mod = relay.transform.InferType()(mod)
+
+    assert tvm.ir.structural_equal(mod, after())
+
+
 def test_multiple_luts():
     """Test that when an operation already has a LUT, we don't overwrite that LUT"""
