Skip to content

Commit 0270967

Browse files
[microNPU] Add transform matrices and part matcher to identity op
1 parent 014208e commit 0270967

File tree

3 files changed

+172
-18
lines changed

3 files changed

+172
-18
lines changed

python/tvm/contrib/ethosu/cascader/device_config.py

Lines changed: 31 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -48,9 +48,24 @@ def __init__(self, shape: List[int], layout="NHWC"):
4848
self.width = int(shape[3])
4949
self.depth = int(shape[2]) * int(shape[4])
5050
else:
51-
self.height = int(shape[1])
52-
self.width = int(shape[2])
53-
self.depth = int(shape[3])
51+
# identity layout is NHWC but the shape is not always 4
52+
length = len(shape)
53+
if length == 4:
54+
self.height = int(shape[1])
55+
self.width = int(shape[2])
56+
self.depth = int(shape[3])
57+
elif length == 3:
58+
self.height = int(shape[1])
59+
self.width = int(shape[2])
60+
self.depth = 1
61+
elif length == 2:
62+
self.height = int(shape[0])
63+
self.width = int(shape[1])
64+
self.depth = 1
65+
elif length == 1:
66+
self.height = int(shape[0])
67+
self.width = 1
68+
self.depth = 1
5469

5570
def round_up(self, other: "_Shape"):
5671
self.height = _round_up(self.height, other.height)
@@ -609,18 +624,19 @@ def _get_subkernel_propagator(
609624
stride_w = int(op_attrs.get("stride_w", 1))
610625
transform = ifm_propagator.transform
611626

612-
if input_layout == "NHCWB16":
613-
transform[1][-1] = min(transform[1][-1], self._subkernel_limits[0] - stride_h)
614-
transform[3][-1] = min(transform[3][-1], self._subkernel_limits[1] - stride_w)
615-
else:
616-
transform[1][-1] = min(transform[1][-1], self._subkernel_limits[0] - stride_h)
617-
transform[2][-1] = min(transform[2][-1], self._subkernel_limits[1] - stride_w)
618-
619-
if op_type in ("ethosu_pooling", "ethosu_depthwise_conv2d"):
620-
if output_layout == "NHCWB16" and input_layout == "NHWC":
621-
transform[3][-1] = depth
622-
elif output_layout == "NHCWB16" and input_layout == "NHCWB16":
623-
transform[2][-1] = 1 + ((depth - 1) // 16)
627+
if op_type != "ethosu_identity":
628+
if input_layout == "NHCWB16":
629+
transform[1][-1] = min(transform[1][-1], self._subkernel_limits[0] - stride_h)
630+
transform[3][-1] = min(transform[3][-1], self._subkernel_limits[1] - stride_w)
631+
else:
632+
transform[1][-1] = min(transform[1][-1], self._subkernel_limits[0] - stride_h)
633+
transform[2][-1] = min(transform[2][-1], self._subkernel_limits[1] - stride_w)
634+
635+
if op_type in ("ethosu_pooling", "ethosu_depthwise_conv2d"):
636+
if output_layout == "NHCWB16" and input_layout == "NHWC":
637+
transform[3][-1] = depth
638+
elif output_layout == "NHCWB16" and input_layout == "NHCWB16":
639+
transform[2][-1] = 1 + ((depth - 1) // 16)
624640

625641
return Propagator(transform, ifm_propagator.offset)
626642

python/tvm/relay/backend/contrib/ethosu/te/identity.py

Lines changed: 83 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,10 @@
1616
# under the License.
1717
# pylint: disable=invalid-name,unused-argument
1818
"""Tensor Expression for identity"""
19+
import numpy as np
1920
from tvm import te
21+
from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher
22+
2023
from .dma import read_compute, write_compute
2124

2225

@@ -56,7 +59,6 @@ def identity_compute(
5659
-------
5760
te.Tensor
5861
The Output Feature Map tensor.
59-
6062
"""
6163
dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
6264
id_attrs = {"op": "ethosu_identity", "activation": activation}
@@ -76,7 +78,85 @@ def identity_compute(
7678
name="ethosu_identity",
7779
attrs=id_attrs,
7880
)
81+
length = len(ifm.shape)
82+
ifm_matrix = np.identity(length + 1)
83+
offset = np.zeros(length, dtype="int64")
84+
ifm_propagator = Propagator(
85+
ifm_matrix,
86+
offset.tolist(),
87+
)
88+
propagator_attrs = {
89+
"ifm_propagator": ifm_propagator,
90+
}
91+
return write_compute(identity, ofm_zero_point, ofm_scale, attrs=propagator_attrs)
92+
93+
94+
@register_matcher
def match_ethosu_identity(output_tensor, device_config):
    """Try to recognise an NPU identity operation in a Tensor Expression.

    Walking back from the output, the expected producer chain is
    ``ethosu_write`` <- ``ethosu_identity`` <- ``ethosu_read``. When that chain
    is present an EthosuPart modelling it is built; otherwise the matcher
    declines by returning None.

    Parameters
    ----------
    output_tensor : tvm.te.Tensor
        The tensor to attempt to match with.
    device_config : EthosuDeviceConfig
        Target device configuration

    Returns
    -------
    Union[None, EthosuPart]
        The created EthosuPart if there was a match, otherwise None.
    """
    # Guard clauses: bail out as soon as one link of the chain is wrong.
    if output_tensor.op.name != "ethosu_write":
        return None
    identity_tensor = output_tensor.op.input_tensors[0]
    if identity_tensor.op.name != "ethosu_identity":
        return None
    read_tensor = identity_tensor.op.input_tensors[0]
    if read_tensor.op.name != "ethosu_read":
        return None

    ifm = read_tensor.op.input_tensors[0]
    subgraph = TESubgraph([ifm], output_tensor)
    # The propagator is read from the attrs of the write op.
    ifm_propagator = output_tensor.op.attrs["ifm_propagator"]

    ifm_dtype = ifm.dtype
    ofm_dtype = output_tensor.dtype

    # Channels are taken from the fourth axis when one exists; lower-rank
    # inputs are treated as single-channel.
    ifm_shape = ifm.shape
    if len(ifm_shape) > 3:
        channels = int(ifm_shape[3])
    else:
        channels = 1

    # NOTE(review): the literal 1, 1 arguments here and at the end of
    # get_valid_block_configs are presumably kernel height/width - confirm
    # against EthosuDeviceConfig.
    kernel_steps = device_config.get_kernel_steps(identity_tensor.op.name, 1, 1, ifm_dtype)
    subkernels = len(kernel_steps)

    # Input and output both use the same fixed layout string.
    layout = "NHWC"
    output_quantum = device_config.get_output_quantum(layout)

    valid_block_configs = device_config.get_valid_block_configs(
        ifm_propagator,
        identity_tensor.op.attrs,
        output_tensor.shape,
        channels,  # ofm channels (same as ifm for identity)
        channels,  # ifm channels
        layout,  # output layout
        layout,  # input layout
        ifm_dtype,
        ofm_dtype,
        1,
        1,
    )

    return EthosuPart(
        subgraph,
        [ifm_propagator],
        output_quantum,
        subkernels,
        valid_block_configs,
    )
Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,58 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import pytest
18+
19+
pytest.importorskip("ethosu.vela")
20+
21+
import numpy as np
22+
23+
from tvm import te
24+
import tvm.contrib.ethosu.cascader as cs
25+
from tvm.relay.backend.contrib.ethosu.te.identity import match_ethosu_identity, identity_compute
26+
from .infra import make_matrices
27+
28+
29+
def test_ethosu_identity_matcher():
    """Match an identity TE and check the resulting part's single propagator.

    The propagator attached by identity_compute is expected to carry an
    identity transform of size (rank + 1) and an all-zero offset of size rank.
    """
    channels = 21
    ifm = te.placeholder((1, 12, 15, channels), dtype="int8")
    out = identity_compute(
        ifm=ifm,
        lut=te.placeholder((), dtype="uint8"),
        ifm_scale=1,
        ifm_zero_point=0,
        ofm_scale=1,
        ofm_zero_point=0,
        activation="NONE",
    )

    # Expected propagator contents, derived from the input rank.
    rank = len(ifm.shape)
    expected_transform = np.identity(rank + 1).tolist()
    expected_offset = np.zeros(rank, dtype="int64").tolist()

    part = match_ethosu_identity(out, cs.EthosuDeviceConfig("ethos-u55-256"))

    assert isinstance(part, cs.EthosuPart)
    assert len(part.propagators) == 1
    (propagator,) = part.propagators
    assert propagator.transform == expected_transform
    assert propagator.offset == expected_offset
55+
56+
57+
if __name__ == "__main__":
58+
pytest.main([__file__])

0 commit comments

Comments
 (0)