Skip to content

Commit b78678f

Browse files
ekaldambaret
authored andcommitted
[microNPU] Determine block configs using the cascader (apache#10695)
The cascader needs to be able to choose the block config for operations in order to accurately model their performance. The cascader must attach the chosen block config to the te.Schedule. This is done using a pragma. The chosen block config is also added to the TIR spec. If the cascader hasn't set a block config, it defaults to the existing block config selection behaviour. Co-authored-by: Matthew Barrett <[email protected]> Co-authored-by: Matthew Barrett <[email protected]>
1 parent 80bd439 commit b78678f

24 files changed

+246
-117
lines changed

python/tvm/contrib/ethosu/cascader/scheduler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tvm import tir
2525
from .cascader_options import CascaderOptions
2626
from .graph import CascaderGraph, Part, Tensor, TESubgraph
27+
from .parts import EthosuPart
2728
from .tensor_config import MemoryRegion
2829
from .proposal import Proposal
2930
from .proposal_generator import generate_proposals
@@ -125,6 +126,23 @@ def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
125126
126127
"""
127128
for plan in proposal.plans:
129+
for part in plan.part_group:
130+
if isinstance(part, EthosuPart):
131+
tensor_config = plan.tensor_configs[part.output_tensor]
132+
stripe_config = tensor_config.stripe_configs[0]
133+
block_config = part.get_block_config(stripe_config)
134+
iv = part.subgraph.output_tensor.op.axis[0]
135+
block_shape = block_config.output_shape
136+
if len(block_shape) == 4:
137+
height, width, depth = block_shape[1:]
138+
else:
139+
height = block_shape[1]
140+
width = block_shape[3]
141+
depth = block_shape[2] * block_shape[4]
142+
sch[part.subgraph.output_tensor].pragma(iv, "block_config_height", height)
143+
sch[part.subgraph.output_tensor].pragma(iv, "block_config_width", width)
144+
sch[part.subgraph.output_tensor].pragma(iv, "block_config_depth", depth)
145+
128146
output_tensor_config = plan.output_config
129147
output_tensor = output_tensor_config.tensor
130148
output_part = output_tensor.producers[0]

python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ def get_binary_elementwise_params(
8686
# Get feature map info
8787
serial_ifm, _ = get_ifm_params(input_pointer, producers)
8888
serial_ifm2, _ = get_ifm_params(input_pointer1, producers)
89-
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
89+
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
90+
output_pointer, consumers, producers
91+
)
9092
# Get activation info
9193
serial_activation = SerialActivation(
9294
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
@@ -100,6 +102,7 @@ def get_binary_elementwise_params(
100102
reversed_operands=reversed_operands,
101103
activation=serial_activation,
102104
rounding_mode=attrs["rounding_mode"],
105+
block_config=serial_block_config,
103106
),
104107
output_pointer,
105108
replace_pointer,

python/tvm/relay/backend/contrib/ethosu/tir/convolution.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def get_conv2d_params(stmt, producers, consumers):
6363
output_pointer = stores[0].buffer.data
6464
# Get feature map info
6565
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
66-
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
66+
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
67+
output_pointer, consumers, producers
68+
)
6769
# Get kernel info
6870
serial_kernel = SerialKernel(
6971
width=int(rw.extent),
@@ -103,6 +105,7 @@ def get_conv2d_params(stmt, producers, consumers):
103105
activation=serial_activation,
104106
rounding_mode=attrs["rounding_mode"],
105107
upscale=attrs["upscale"],
108+
block_config=serial_block_config,
106109
),
107110
output_pointer,
108111
replace_pointer,

python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def get_depthwise_conv2d_params(
7272
output_pointer = stores[0].buffer.data
7373
# Get feature map info
7474
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
75-
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
75+
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
76+
output_pointer, consumers, producers
77+
)
7678
# Get kernel info
7779
serial_kernel = SerialKernel(
7880
width=int(rw.extent),
@@ -113,6 +115,7 @@ def get_depthwise_conv2d_params(
113115
activation=serial_activation,
114116
rounding_mode=attrs["rounding_mode"],
115117
upscale="NONE",
118+
block_config=serial_block_config,
116119
),
117120
output_pointer,
118121
replace_pointer,

python/tvm/relay/backend/contrib/ethosu/tir/dma.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""Extract parameters from the DMA operators in TIR."""
1919
import tvm
2020
from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs
21-
from .spec import SerialFeatureMap, SerialPadding
21+
from .spec import SerialBlockConfig, SerialFeatureMap, SerialPadding
2222

2323

2424
def get_pad_params(stmt):
@@ -253,6 +253,14 @@ def get_write_params(stmt):
253253

254254
base_address = [get_base_address(index) for index in inner.indices]
255255
data_type = inner.buffer.data.type_annotation.element_type.dtype
256+
if "block_config_height" in attrs:
257+
block_config = SerialBlockConfig(
258+
height=int(attrs["block_config_height"]),
259+
width=int(attrs["block_config_width"]),
260+
depth=int(attrs["block_config_depth"]),
261+
)
262+
else:
263+
block_config = SerialBlockConfig(0, 0, 0)
256264
return (
257265
SerialFeatureMap(
258266
data_type=data_type,
@@ -273,6 +281,7 @@ def get_write_params(stmt):
273281
stride_w=strides[1],
274282
stride_c=strides[2],
275283
),
284+
block_config,
276285
input_pointer,
277286
output_pointer,
278287
)
@@ -327,6 +336,8 @@ def get_ofm_params(pointer, consumers, producers):
327336
-------
328337
serial_ifm : SerialFeatureMap
329338
The serializable OFM.
339+
serial_block_config : SerialBlockConfig
340+
The serializable block config.
330341
output_pointer : tvm.tir.Var
331342
The pointer that the OFM DMA pipeline produces.
332343
is_allocator : bool
@@ -336,11 +347,11 @@ def get_ofm_params(pointer, consumers, producers):
336347
convert_to_nhcwb16 = consumers[pointer]
337348
out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16)
338349
write = consumers[output_pointer]
339-
serial_ofm, _, output_pointer = get_write_params(write)
350+
serial_ofm, serial_block_config, _, output_pointer = get_write_params(write)
340351
is_allocator = True
341352
if output_pointer not in producers:
342353
is_allocator = False
343354
elif producers[output_pointer] != write:
344355
is_allocator = False
345356
serial_ofm.channels = out_channels
346-
return serial_ofm, output_pointer, is_allocator
357+
return serial_ofm, serial_block_config, output_pointer, is_allocator

python/tvm/relay/backend/contrib/ethosu/tir/identity.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
"""Extract information from the identity operator in TIR."""
1919
from typing import Dict, Tuple
2020
import tvm
21-
from .spec import SerialKernel, SerialActivation, SerialPooling, SerialPadding, SerialFeatureMap
21+
from .spec import (
22+
SerialBlockConfig,
23+
SerialKernel,
24+
SerialActivation,
25+
SerialPooling,
26+
SerialPadding,
27+
SerialFeatureMap,
28+
)
2229
from .utils import get_op_attrs, get_base_address, get_strides, get_loads
2330

2431

@@ -164,6 +171,7 @@ def get_identity_params(
164171
activation=serial_activation,
165172
upscale="NONE",
166173
rounding_mode="TFL",
174+
block_config=SerialBlockConfig(0, 0, 0),
167175
),
168176
output_pointer,
169177
replace_pointer,

python/tvm/relay/backend/contrib/ethosu/tir/pooling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def get_pooling_params(
6565
output_pointer = stores[0].buffer.data
6666
# Get feature map info
6767
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
68-
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
68+
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
69+
output_pointer, consumers, producers
70+
)
6971
# Get kernel info
7072
serial_kernel = SerialKernel(
7173
width=int(rw.extent),
@@ -90,6 +92,7 @@ def get_pooling_params(
9092
activation=serial_activation,
9193
rounding_mode=attrs["rounding_mode"],
9294
upscale=attrs["upscale"],
95+
block_config=serial_block_config,
9396
),
9497
output_pointer,
9598
replace_pointer,

python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ def _add_pragmas(stage, ax):
209209
for attr, val in stage.op.attrs.items():
210210
if attr not in ("op", "lut") and not isinstance(val, Propagator):
211211
stage.pragma(ax, str(attr), val)
212+
if stage.op.axis[0] in stage.iter_var_attrs:
213+
attrs = stage.iter_var_attrs[stage.op.axis[0]]
214+
if "block_config_height" in attrs.pragma_keys:
215+
pragmas = dict(zip([k.value for k in attrs.pragma_keys], attrs.pragma_values))
216+
stage.pragma(ax, "block_config_height", pragmas["block_config_height"])
217+
stage.pragma(ax, "block_config_width", pragmas["block_config_width"])
218+
stage.pragma(ax, "block_config_depth", pragmas["block_config_depth"])
212219

213220
for stage in sch.stages:
214221
if (

python/tvm/relay/backend/contrib/ethosu/tir/spec.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,16 @@ def __init__(self, op: str, clip_min: int, clip_max: int):
174174
self.clip_max = clip_max
175175

176176

177+
class SerialBlockConfig(SerializableFormat):
178+
"""Specialization class to retrieve arguments of a BlockConfig
179+
(similar to NpuBlockConfig of Vela) on a predefined ordering"""
180+
181+
def __init__(self, height: int, width: int, depth: int):
182+
self.height = height
183+
self.width = width
184+
self.depth = depth
185+
186+
177187
class Serial2DConvolution(SerializableFormat):
178188
"""Specialization class to retrieve arguments of
179189
a ethosu.conv2d tir extern call on a predefined ordering"""
@@ -190,6 +200,7 @@ def __init__(
190200
activation: SerialActivation,
191201
rounding_mode: str,
192202
upscale: str,
203+
block_config: SerialBlockConfig,
193204
):
194205
self.ifm = ifm
195206
self.ofm = ofm
@@ -201,6 +212,7 @@ def __init__(
201212
self.activation = activation
202213
self.rounding_mode = rounding_mode
203214
self.upscale = upscale
215+
self.block_config = block_config
204216

205217

206218
class Serial2DDepthwise(SerializableFormat):
@@ -219,6 +231,7 @@ def __init__(
219231
activation: SerialActivation,
220232
rounding_mode: str,
221233
upscale: str,
234+
block_config: SerialBlockConfig,
222235
):
223236
self.ifm = ifm
224237
self.ofm = ofm
@@ -230,6 +243,7 @@ def __init__(
230243
self.activation = activation
231244
self.rounding_mode = rounding_mode
232245
self.upscale = upscale
246+
self.block_config = block_config
233247

234248

235249
class SerialCopy(SerializableFormat):
@@ -261,6 +275,7 @@ def __init__(
261275
activation: SerialActivation,
262276
rounding_mode: str,
263277
upscale: str,
278+
block_config: SerialBlockConfig,
264279
):
265280
self.ifm = ifm
266281
self.ofm = ofm
@@ -270,6 +285,7 @@ def __init__(
270285
self.activation = activation
271286
self.rounding_mode = rounding_mode
272287
self.upscale = upscale
288+
self.block_config = block_config
273289

274290

275291
class SerialBinaryElementwise(SerializableFormat):
@@ -285,6 +301,7 @@ def __init__(
285301
reversed_operands: bool,
286302
activation: SerialActivation,
287303
rounding_mode: str,
304+
block_config: SerialBlockConfig,
288305
):
289306
self.ifm = ifm
290307
self.ifm2 = ifm2
@@ -293,6 +310,7 @@ def __init__(
293310
self.reversed_operands = reversed_operands
294311
self.activation = activation
295312
self.rounding_mode = rounding_mode
313+
self.block_config = block_config
296314

297315

298316
class SerialUnaryElementwise(SerializableFormat):
@@ -306,9 +324,11 @@ def __init__(
306324
operator_type: str,
307325
activation: SerialActivation,
308326
rounding_mode: str,
327+
block_config: SerialBlockConfig,
309328
):
310329
self.ifm = ifm
311330
self.ofm = ofm
312331
self.operator_type = operator_type
313332
self.activation = activation
314333
self.rounding_mode = rounding_mode
334+
self.block_config = block_config

python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def get_unary_elementwise_params(stmt, producers, consumers):
6161
output_pointer = inner.buffer.data
6262
# Get feature map info
6363
serial_ifm, _ = get_ifm_params(input_pointer, producers)
64-
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
64+
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
65+
output_pointer, consumers, producers
66+
)
6567
# Get activation info
6668
serial_activation = SerialActivation(
6769
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
@@ -73,6 +75,7 @@ def get_unary_elementwise_params(stmt, producers, consumers):
7375
operator_type=attrs["operator_type"],
7476
activation=serial_activation,
7577
rounding_mode=attrs["rounding_mode"],
78+
block_config=serial_block_config,
7679
),
7780
output_pointer,
7881
replace_pointer,

0 commit comments

Comments
 (0)